Example #1
  def __init__(self, model, src_file=None, trg_file=None, dev_every=0,
               batcher=bare(SrcBatcher, batch_size=32), loss_calculator=None,
               run_for_epochs=None, lr_decay=1.0, lr_decay_times=3, patience=1,
               initial_patience=None, dev_tasks=None, restart_trainer=False,
               reload_command=None, name=None, sample_train_sents: Optional[int] = None,
               max_num_train_sents=None, max_src_len=None, max_trg_len=None):
    self.src_file = src_file
    self.trg_file = trg_file
    self.dev_tasks = dev_tasks

    if lr_decay > 1.0 or lr_decay <= 0.0:
      raise RuntimeError("illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
    self.lr_decay = lr_decay
    self.patience = patience
    self.initial_patience = initial_patience
    self.lr_decay_times = lr_decay_times
    self.restart_trainer = restart_trainer
    self.run_for_epochs = run_for_epochs

    self.early_stopping_reached = False
    # training state
    self.training_state = TrainingState()

    self.reload_command = reload_command

    self.model = model
    self.loss_calculator = loss_calculator or LossCalculator(MLELoss())

    self.sample_train_sents = sample_train_sents
    self.max_num_train_sents = max_num_train_sents
    self.max_src_len = max_src_len
    self.max_trg_len = max_trg_len

    self.batcher = batcher
    self.logger = BatchLossTracker(self, dev_every, name)
Example #2
class SimpleTrainingTask(TrainingTask, Serializable):
  """
  Args:
    model: a generator.GeneratorModel object
    src_file: The file for the source data.
    trg_file: The file for the target data.
    dev_every (int): dev checkpoints every n sentences (0 for only after epoch)
    batcher: Type of batcher
    loss_calculator:
    run_for_epochs (int): number of epochs (None for unlimited epochs)
    lr_decay (float):
    lr_decay_times (int):  Early stopping after decaying learning rate a certain number of times
    patience (int): apply LR decay after dev scores haven't improved over this many checkpoints
    initial_patience (int): if given, allows adjusting patience for the first LR decay
    dev_tasks: A list of tasks to run on the development set
    restart_trainer: Restart trainer (useful for Adam) and revert weights to best dev checkpoint when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
    reload_command: Command to change the input data after each epoch.
                         --epoch EPOCH_NUM will be appended to the command.
                         To just reload the data after each epoch set the command to 'true'.
    sample_train_sents: If given, load a random subset of training sentences before each epoch. Useful when training data does not fit in memory.
    max_num_train_sents:
    max_src_len:
    max_trg_len:
    name: will be prepended to log outputs if given
  """
  yaml_tag = '!SimpleTrainingTask'

  @serializable_init
  def __init__(self, model, src_file=None, trg_file=None, dev_every=0,
               batcher=bare(SrcBatcher, batch_size=32), loss_calculator=None,
               run_for_epochs=None, lr_decay=1.0, lr_decay_times=3, patience=1,
               initial_patience=None, dev_tasks=None, restart_trainer=False,
               reload_command=None, name=None, sample_train_sents: Optional[int] = None,
               max_num_train_sents=None, max_src_len=None, max_trg_len=None):
    self.src_file = src_file
    self.trg_file = trg_file
    self.dev_tasks = dev_tasks

    if lr_decay > 1.0 or lr_decay <= 0.0:
      raise RuntimeError("illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
    self.lr_decay = lr_decay
    self.patience = patience
    self.initial_patience = initial_patience
    self.lr_decay_times = lr_decay_times
    self.restart_trainer = restart_trainer
    self.run_for_epochs = run_for_epochs

    self.early_stopping_reached = False
    # training state
    self.training_state = TrainingState()

    self.reload_command = reload_command

    self.model = model
    self.loss_calculator = loss_calculator or LossCalculator(MLELoss())

    self.sample_train_sents = sample_train_sents
    self.max_num_train_sents = max_num_train_sents
    self.max_src_len = max_src_len
    self.max_trg_len = max_trg_len

    self.batcher = batcher
    self.logger = BatchLossTracker(self, dev_every, name)

  def _augment_data_initial(self):
    """
    Called before loading corpus for the first time, if reload_command is given
    """
    augment_command = self.reload_command
    logger.debug('initial augmentation')
    if self._augmentation_handle is None:
      # first run
      self._augmentation_handle = Popen(augment_command + " --epoch 0", shell=True)
      self._augmentation_handle.wait()

  def _augment_data_next_epoch(self):
    """
    This is run in the background if reload_command is given to prepare data for the next epoch
    """
    augment_command = self.reload_command
    if self._augmentation_handle is None:
      # first run
      self._augmentation_handle = Popen(augment_command + " --epoch %d" % self.training_state.epoch_num, shell=True)
      self._augmentation_handle.wait()

    self._augmentation_handle.poll()
    retcode = self._augmentation_handle.returncode
    if retcode is not None:
      if self.training_state.epoch_num > 0:
        logger.info('using reloaded data')
      # reload the data   
      self.src_data, self.trg_data, self.src_batches, self.trg_batches = \
          xnmt.input_reader.read_parallel_corpus(self.model.src_reader, self.model.trg_reader,
                                          self.src_file, self.trg_file,
                                          batcher=self.batcher, sample_sents=self.sample_train_sents,
                                          max_num_sents=self.max_num_train_sents,
                                          max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
      # restart data generation
      self._augmentation_handle = Popen(augment_command + " --epoch %d" % self.training_state.epoch_num, shell=True)
    else:
      logger.info('new data set is not ready yet, using data from last epoch.')

  @register_xnmt_event
  def new_epoch(self, training_task, num_sents):
    """
    New epoch event.

    Args:
      training_task: Indicates which training task is advancing to the next epoch.
      num_sents: Number of sentences in the upcoming epoch (may change between epochs)
    """
    pass

  def should_stop_training(self):
    """
    Signal stopping if self.early_stopping_reached is marked or we exhausted the number of requested epochs.
    """
    return self.early_stopping_reached \
      or self.run_for_epochs is not None and (self.training_state.epoch_num > self.run_for_epochs \
                                              or (self.training_state.epoch_num == self.run_for_epochs and
                                                  self.training_state.steps_into_epoch >= self.cur_num_minibatches() - 1))

  def cur_num_minibatches(self):
    """
    Current number of minibatches (may change between epochs, e.g. for randomizing batchers or if reload_command is given)
    """
    return len(self.src_batches)

  def cur_num_sentences(self):
    """
    Current number of parallel sentences (may change between epochs, e.g. if reload_command is given)
    """
    return len(self.src_data)

  def advance_epoch(self):
    """
    Shifts internal state to the next epoch, including data (re-)loading, batch re-packing and shuffling.
    """
    if self.reload_command is not None:
      if self.training_state.epoch_num==0:
        self._augmentation_handle = None
        self._augment_data_initial()
      else:
        self._augment_data_next_epoch()
    if self.training_state.epoch_num==0 or self.sample_train_sents:
      self.src_data, self.trg_data, self.src_batches, self.trg_batches = \
        xnmt.input_reader.read_parallel_corpus(self.model.src_reader, self.model.trg_reader,
                                               self.src_file, self.trg_file,
                                               batcher=self.batcher, sample_sents=self.sample_train_sents,
                                               max_num_sents=self.max_num_train_sents,
                                               max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
    self.training_state.epoch_seed = random.randint(1,2147483647)
    random.seed(self.training_state.epoch_seed)
    np.random.seed(self.training_state.epoch_seed)
    self.src_batches, self.trg_batches = \
      self.batcher.pack(self.src_data, self.trg_data)
    self.training_state.epoch_num += 1
    self.training_state.steps_into_epoch = 0
    self.minibatch_order = list(range(0, self.cur_num_minibatches()))
    np.random.shuffle(self.minibatch_order)
    self.new_epoch(training_task=self, num_sents=self.cur_num_sentences())

  def next_minibatch(self):
    """
    Infinitely loops over training minibatches and calls advance_epoch() after every complete sweep over the corpus.

    Returns:
      Generator yielding (src_batch,trg_batch) tuples
    """
    while True:
      self.advance_epoch()
      for batch_num in self.minibatch_order:
        src = self.src_batches[batch_num]
        trg = self.trg_batches[batch_num]
        yield src, trg
        self.training_state.steps_into_epoch += 1

  def training_step(self, src, trg):
    """
    Performs forward pass, backward pass, parameter update for the given minibatch
    """
    loss_builder = LossBuilder()
    standard_loss = self.model.calc_loss(src, trg, self.loss_calculator)
    additional_loss = self.model.calc_additional_loss(standard_loss)
    loss_builder.add_loss("standard_loss", standard_loss)
    loss_builder.add_loss("additional_loss", additional_loss)

    loss_value = loss_builder.compute()
    self.logger.update_epoch_loss(src, trg, loss_builder.get_loss_stats())
    self.logger.report_train_process()

    return loss_value

  def checkpoint_needed(self):
    return self.logger.should_report_dev()

  def checkpoint(self, control_learning_schedule=True):
    """
    Performs a dev checkpoint

    Args:
      control_learning_schedule: If False, only evaluate dev data.
                                      If True, also perform model saving, LR decay etc. if needed.
    Returns:
      True if the model needs saving, False otherwise
    """
    ret = False
    self.logger.new_dev()

    # Perform evaluation
    if self.dev_tasks and len(self.dev_tasks) > 0:
      dev_scores = []
      for dev_task in self.dev_tasks:
        dev_score, dev_word_cnt = dev_task.eval()
        if type(dev_score) == list:
          dev_scores.extend(dev_score)
        else:
          dev_scores.append(dev_score)
      # TODO: This is passing "1" for the number of words, as this is not implemented yet
      self.logger.set_dev_score(dev_word_cnt, dev_scores[0])
      for dev_score in dev_scores[1:]:
        self.logger.report_auxiliary_score(dev_score)

    # Control the learning schedule
    if control_learning_schedule:
      logger.info("> Checkpoint")
      # Write out the model if it's the best one
      if self.logger.report_dev_and_check_model():
        ret = True
        self.training_state.cur_attempt = 0
      else:
        # otherwise: learning rate decay / early stopping
        self.training_state.cur_attempt += 1
        if self.lr_decay < 1.0:
          should_decay = False
          if (self.initial_patience is None or self.training_state.num_times_lr_decayed>0) \
                  and self.training_state.cur_attempt >= self.patience:
            should_decay = True
          if self.initial_patience is not None and self.training_state.num_times_lr_decayed==0 \
                  and self.training_state.cur_attempt >= self.initial_patience:
            should_decay = True
          if should_decay:
            self.training_state.num_times_lr_decayed += 1
            if self.training_state.num_times_lr_decayed > self.lr_decay_times:
              logger.info('  Early stopping')
              self.early_stopping_reached = True
            else:
              self.training_state.cur_attempt = 0
              self.trainer.learning_rate *= self.lr_decay
              logger.info('  new learning rate: %s' % self.trainer.learning_rate)
              if self.restart_trainer:
                logger.info('  restarting trainer and reverting learned weights to best checkpoint..')
                self.trainer.restart()
                ParamManager.param_col.revert_to_best_model()

    return ret
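
Example #2 is the most self-contained variant: data loading, epoch handling, and the dev checkpoint all live on the task itself. The driver below is a minimal usage sketch, not the actual xnmt training regimen; the import paths, the external trainer object, and the assumption that training_step() returns a DyNet expression with a backward() method are illustrative guesses.

# Minimal driver sketch for Example #2 (illustrative only, not xnmt's regimen).
# The import paths below are assumptions about the package layout.
from xnmt.training_task import SimpleTrainingTask
from xnmt.batcher import SrcBatcher
from xnmt.serialize.serializable import bare

def run_training(model, trainer, src_file, trg_file):
  task = SimpleTrainingTask(model=model, src_file=src_file, trg_file=trg_file,
                            batcher=bare(SrcBatcher, batch_size=32),
                            run_for_epochs=2, lr_decay=0.5, lr_decay_times=3,
                            patience=2, dev_every=1000)
  task.trainer = trainer          # checkpoint() reads self.trainer when decaying the LR
  for src, trg in task.next_minibatch():   # infinite generator; stop condition below
    if task.should_stop_training():
      break
    loss = task.training_step(src, trg)    # forward pass; returns the batch loss
    loss.backward()                        # assumption: loss is a DyNet expression
    trainer.update()                       # parameter update by the external trainer
    if task.checkpoint_needed():
      task.checkpoint()                    # dev evaluation, best-model check, LR decay
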
Example #3
class SimpleTrainingTask(TrainingTask, Serializable):
    yaml_tag = u'!SimpleTrainingTask'

    def __init__(self,
                 yaml_context,
                 corpus_parser,
                 model,
                 glob={},
                 dev_every=0,
                 batcher=None,
                 loss_calculator=None,
                 pretrained_model_file="",
                 src_format="text",
                 run_for_epochs=None,
                 lr_decay=1.0,
                 lr_decay_times=3,
                 patience=1,
                 initial_patience=None,
                 dev_metrics="",
                 schedule_metric="loss",
                 restart_trainer=False,
                 reload_command=None,
                 name=None,
                 inference=None):
        """
    :param yaml_context:
    :param corpus_parser: an input.InputReader object
    :param model: a generator.GeneratorModel object
    :param dev_every (int): dev checkpoints every n sentences (0 for only after epoch)
    :param batcher: Type of batcher. Defaults to SrcBatcher of batch size 32.
    :param loss_calculator:
    :param pretrained_model_file: Path of pre-trained model file
    :param src_format: Format of input data: text/contvec
    :param lr_decay (float):
    :param lr_decay_times (int):  Early stopping after decaying learning rate a certain number of times
    :param patience (int): apply LR decay after dev scores haven't improved over this many checkpoints
    :param initial_patience (int): if given, allows adjusting patience for the first LR decay
    :param dev_metrics: Comma-separated list of evaluation metrics (bleu/wer/cer)
    :param schedule_metric: determine learning schedule based on this dev_metric (loss/bleu/wer/cer)
    :param restart_trainer: Restart trainer (useful for Adam) and revert weights to best dev checkpoint when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
    :param reload_command: Command to change the input data after each epoch.
                           --epoch EPOCH_NUM will be appended to the command.
                           To just reload the data after each epoch set the command to 'true'.
    :param name: will be prepended to log outputs if given
    :param inference: used for inference during dev checkpoints if dev_metrics are specified
    """
        assert yaml_context is not None
        self.yaml_context = yaml_context
        self.model_file = self.yaml_context.dynet_param_collection.model_file
        self.yaml_serializer = YamlSerializer()

        if lr_decay > 1.0 or lr_decay <= 0.0:
            raise RuntimeError(
                "illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
        self.lr_decay = lr_decay
        self.patience = patience
        self.initial_patience = initial_patience
        self.lr_decay_times = lr_decay_times
        self.restart_trainer = restart_trainer
        self.run_for_epochs = run_for_epochs

        self.early_stopping_reached = False
        # training state
        self.training_state = TrainingState()

        self.evaluators = [
            s.lower() for s in dev_metrics.split(",") if s.strip() != ""
        ]
        if schedule_metric.lower() not in self.evaluators:
            self.evaluators.append(schedule_metric.lower())
        if "loss" not in self.evaluators: self.evaluators.append("loss")
        if dev_metrics:
            self.inference = inference or SimpleInference()

        self.reload_command = reload_command
        if reload_command is not None:
            self._augmentation_handle = None
            self._augment_data_initial()

        self.model = model
        self.corpus_parser = corpus_parser
        self.loss_calculator = loss_calculator or LossCalculator(MLELoss())
        self.pretrained_model_file = pretrained_model_file
        if self.pretrained_model_file:
            self.yaml_context.dynet_param_collection.load_from_data_file(
                self.pretrained_model_file + '.data')

        self.batcher = batcher or SrcBatcher(32)
        if src_format == "contvec":
            self.batcher.pad_token = np.zeros(self.model.src_embedder.emb_dim)
        self.pack_batches()
        self.logger = BatchLossTracker(self, dev_every, name)

        self.schedule_metric = schedule_metric.lower()

    def dependent_init_params(self, initialized_subcomponents):
        """
    Overwrite Serializable.dependent_init_params() to realize sharing of vocab size between embedders and corpus parsers
    """
        return [
            DependentInitParam(param_descr="model.src_embedder.vocab_size",
                               value_fct=lambda: initialized_subcomponents[
                                   "corpus_parser"].src_reader.vocab_size()),
            DependentInitParam(param_descr="model.decoder.vocab_size",
                               value_fct=lambda: initialized_subcomponents[
                                   "corpus_parser"].trg_reader.vocab_size()),
            DependentInitParam(param_descr="model.trg_embedder.vocab_size",
                               value_fct=lambda: initialized_subcomponents[
                                   "corpus_parser"].trg_reader.vocab_size()),
            DependentInitParam(param_descr="model.src_embedder.vocab",
                               value_fct=lambda: initialized_subcomponents[
                                   "corpus_parser"].src_reader.vocab),
            DependentInitParam(param_descr="model.trg_embedder.vocab",
                               value_fct=lambda: initialized_subcomponents[
                                   "corpus_parser"].trg_reader.vocab)
        ]

    def pack_batches(self):
        """
    Packs src/trg examples into batches, possibly randomized. No shuffling performed here.
    """
        self.train_src, self.train_trg = \
          self.batcher.pack(self.corpus_parser.get_training_corpus().train_src_data, self.corpus_parser.get_training_corpus().train_trg_data)
        self.dev_src, self.dev_trg = \
          self.batcher.pack(self.corpus_parser.get_training_corpus().dev_src_data, self.corpus_parser.get_training_corpus().dev_trg_data)

    def _augment_data_initial(self):
        """
    Called before loading corpus for the first time, if reload_command is given
    """
        augment_command = self.reload_command
        print('initial augmentation')
        if self._augmentation_handle is None:
            # first run
            self._augmentation_handle = Popen(augment_command + " --epoch 0",
                                              shell=True)
            self._augmentation_handle.wait()

    def _augment_data_next_epoch(self):
        """
    This is run in the background if reload_command is given to prepare data for the next epoch
    """
        augment_command = self.reload_command
        if self._augmentation_handle is None:
            # first run
            self._augmentation_handle = Popen(
                augment_command +
                " --epoch %d" % self.training_state.epoch_num,
                shell=True)
            self._augmentation_handle.wait()

        self._augmentation_handle.poll()
        retcode = self._augmentation_handle.returncode
        if retcode is not None:
            if self.training_state.epoch_num > 0:
                print('using reloaded data')
            # reload the data
            self.corpus_parser._read_training_corpus(
                self.corpus_parser.training_corpus)  # TODO: fix
            # restart data generation
            self._augmentation_handle = Popen(
                augment_command +
                " --epoch %d" % self.training_state.epoch_num,
                shell=True)
        else:
            print('new data set is not ready yet, using data from last epoch.')

    @register_xnmt_event
    def new_epoch(self, training_regimen, num_sents):
        """
    New epoch event.
    :param training_regimen: Indicates which training regimen is advancing to the next epoch.
    :param num_sents: Number of sentences in the upcoming epoch (may change between epochs)
    """
        pass

    def should_stop_training(self):
        """
    Signal stopping if self.early_stopping_reached is marked or we exhausted the number of requested epochs.
    """
        return self.early_stopping_reached \
          or self.training_state.epoch_num > self.run_for_epochs \
          or (self.training_state.epoch_num == self.run_for_epochs and self.training_state.steps_into_epoch >= self.cur_num_minibatches()-1)

    def cur_num_minibatches(self):
        """
    Current number of minibatches (may change between epochs, e.g. for randomizing batchers or if reload_command is given)
    """
        return len(self.train_src)

    def cur_num_sentences(self):
        """
    Current number of parallel sentences (may change between epochs, e.g. if reload_command is given)
    """
        return len(self.corpus_parser.training_corpus.train_src_data)

    def advance_epoch(self):
        """
    Shifts internal state to the next epoch, including batch re-packing and shuffling.
    """
        if self.reload_command is not None:
            self._augment_data_next_epoch()
        self.training_state.epoch_seed = random.randint(1, 2147483647)
        random.seed(self.training_state.epoch_seed)
        np.random.seed(self.training_state.epoch_seed)
        self.pack_batches()
        self.training_state.epoch_num += 1
        self.training_state.steps_into_epoch = 0
        self.minibatch_order = list(range(0, self.cur_num_minibatches()))
        np.random.shuffle(self.minibatch_order)
        self.new_epoch(training_regimen=self,
                       num_sents=self.cur_num_sentences())

    def next_minibatch(self):
        """
    Infinitely loops over training minibatches and calls advance_epoch() after every complete sweep over the corpus.
    :returns: Generator yielding (src_batch,trg_batch) tuples 
    """
        while True:
            self.advance_epoch()
            for batch_num in self.minibatch_order:
                src = self.train_src[batch_num]
                trg = self.train_trg[batch_num]
                yield src, trg
                self.training_state.steps_into_epoch += 1

    def training_step(self, src, trg):
        """
    Performs forward pass, backward pass, parameter update for the given minibatch
    """
        loss_builder = LossBuilder()
        standard_loss = self.model.calc_loss(src, trg, self.loss_calculator)
        if standard_loss.__class__ == LossBuilder:
            loss = None
            for loss_name, loss_expr in standard_loss.loss_nodes:
                loss_builder.add_loss(loss_name, loss_expr)
                loss = loss_expr if not loss else loss + loss_expr
            standard_loss = loss

        else:
            loss_builder.add_loss("loss", standard_loss)

        additional_loss = self.model.calc_additional_loss(
            dy.nobackprop(-standard_loss))
        if additional_loss is not None:
            loss_builder.add_loss("additional_loss", additional_loss)

        loss_value = loss_builder.compute()
        self.logger.update_epoch_loss(src, trg, loss_builder)
        self.logger.report_train_process()

        return loss_value

    def checkpoint_needed(self):
        return self.logger.should_report_dev()

    def checkpoint(self,
                   control_learning_schedule=True,
                   out_ext=".dev_hyp",
                   ref_ext=".dev_ref",
                   encoding='utf-8'):
        """
    Performs a dev checkpoint
    :param control_learning_schedule: If False, only evaluate dev data.
                                      If True, also perform model saving, LR decay etc. if needed.
    :param out_ext:
    :param ref_ext:
    :param encoding:
    :returns: True if the model needs saving, False otherwise
    """
        ret = False
        self.logger.new_dev()
        trg_words_cnt, loss_score = self.compute_dev_loss(
        )  # forced decoding loss

        eval_scores = {"loss": loss_score}
        if len(list(filter(lambda e: e != "loss", self.evaluators))) > 0:
            trg_file = None
            if self.model_file:
                evaluate_args = {}
                out_file = self.model_file + out_ext
                out_file_ref = self.model_file + ref_ext
                trg_file = out_file
                evaluate_args["hyp_file"] = out_file
                evaluate_args["ref_file"] = out_file_ref
            # Decoding + post_processing
            self.inference(corpus_parser=self.corpus_parser,
                           generator=self.model,
                           batcher=self.batcher,
                           src_file=self.corpus_parser.training_corpus.dev_src,
                           trg_file=trg_file,
                           candidate_id_file=self.corpus_parser.
                           training_corpus.dev_id_file)
            output_processor = self.inference.get_output_processor(
            )  # TODO: hack, refactor
            # Copy Trg to Ref
            processed = []
            with io.open(self.corpus_parser.training_corpus.dev_trg,
                         encoding=encoding) as fin:
                for line in fin:
                    processed.append(
                        output_processor.words_to_string(line.strip().split())
                        + u"\n")
            with io.open(out_file_ref, 'wt', encoding=encoding) as fout:
                for line in processed:
                    fout.write(line)
            # Evaluation
            for evaluator in self.evaluators:
                if evaluator == "loss": continue
                evaluate_args["evaluator"] = evaluator
                eval_score = xnmt.xnmt_evaluate.xnmt_evaluate(**evaluate_args)
                eval_scores[evaluator] = eval_score
        # Logging
        if self.schedule_metric == "loss":
            self.logger.set_dev_score(trg_words_cnt, loss_score)
        else:
            self.logger.set_dev_score(trg_words_cnt,
                                      eval_scores[self.schedule_metric])

        # print previously computed metrics
        for metric in self.evaluators:
            if metric != self.schedule_metric:
                self.logger.report_auxiliary_score(eval_scores[metric])

        if control_learning_schedule:
            print("> Checkpoint")
            # Write out the model if it's the best one
            if self.logger.report_dev_and_check_model(self.model_file):
                if self.model_file is not None:
                    ret = True
                self.training_state.cur_attempt = 0
            else:
                # otherwise: learning rate decay / early stopping
                self.training_state.cur_attempt += 1
                if self.lr_decay < 1.0:
                    should_decay = False
                    if (self.initial_patience is None
                            or self.training_state.num_times_lr_decayed > 0
                        ) and self.training_state.cur_attempt >= self.patience:
                        should_decay = True
                    if self.initial_patience is not None and self.training_state.num_times_lr_decayed == 0 and self.training_state.cur_attempt >= self.initial_patience:
                        should_decay = True
                    if should_decay:
                        self.training_state.num_times_lr_decayed += 1
                        if self.training_state.num_times_lr_decayed > self.lr_decay_times:
                            print('  Early stopping')
                            self.early_stopping_reached = True
                        else:
                            self.trainer.learning_rate *= self.lr_decay
                            print('  new learning rate: %s' %
                                  self.trainer.learning_rate)
                            if self.restart_trainer:
                                print(
                                    '  restarting trainer and reverting learned weights to best checkpoint..'
                                )
                                self.trainer.restart()
                                self.yaml_context.dynet_param_collection.revert_to_best_model(
                                )

        return ret

    def compute_dev_loss(self):
        loss_builder = LossBuilder()
        trg_words_cnt = 0
        for src, trg in zip(self.dev_src, self.dev_trg):
            dy.renew_cg()
            standard_loss = self.model.calc_loss(src, trg,
                                                 self.loss_calculator)
            loss_builder.add_loss("loss", standard_loss)
            trg_words_cnt += self.logger.count_trg_words(trg)
            loss_builder.compute()
        return trg_words_cnt, LossScore(loss_builder.sum() / trg_words_cnt)
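
The learning-schedule branch near the end of checkpoint() is the densest part of both Example #2 and Example #3. Below is a rough standalone restatement of that decision (it follows Example #2, which resets cur_attempt after a decay; the function name and return values are invented for illustration and do not exist in xnmt):

# Standalone restatement of the LR-decay / early-stopping decision in checkpoint().
# decide_lr_schedule() is illustrative only; the real code mutates self.training_state inline.
def decide_lr_schedule(state, lr_decay, lr_decay_times, patience, initial_patience):
  """Return 'wait', 'decay' or 'stop' for a TrainingState-like object after a failed dev check."""
  if lr_decay >= 1.0:
    return 'wait'                                   # decay disabled entirely
  # initial_patience only applies before the first decay, otherwise regular patience.
  if initial_patience is not None and state.num_times_lr_decayed == 0:
    effective_patience = initial_patience
  else:
    effective_patience = patience
  if state.cur_attempt < effective_patience:
    return 'wait'                                   # keep waiting for a better dev score
  state.num_times_lr_decayed += 1
  if state.num_times_lr_decayed > lr_decay_times:
    return 'stop'                                   # early stopping reached
  state.cur_attempt = 0                             # as in Example #2
  return 'decay'                                    # caller multiplies the LR by lr_decay
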
Example #4
    def __init__(self,
                 yaml_context,
                 corpus_parser,
                 model,
                 glob={},
                 dev_every=0,
                 batcher=None,
                 loss_calculator=None,
                 pretrained_model_file="",
                 src_format="text",
                 run_for_epochs=None,
                 lr_decay=1.0,
                 lr_decay_times=3,
                 patience=1,
                 initial_patience=None,
                 dev_metrics="",
                 schedule_metric="loss",
                 restart_trainer=False,
                 reload_command=None,
                 name=None,
                 inference=None):
        """
    :param yaml_context:
    :param corpus_parser: an input.InputReader object
    :param model: a generator.GeneratorModel object
    :param dev_every (int): dev checkpoints every n sentences (0 for only after epoch)
    :param batcher: Type of batcher. Defaults to SrcBatcher of batch size 32.
    :param loss_calculator:
    :param pretrained_model_file: Path of pre-trained model file
    :param src_format: Format of input data: text/contvec
    :param lr_decay (float):
    :param lr_decay_times (int):  Early stopping after decaying learning rate a certain number of times
    :param patience (int): apply LR decay after dev scores haven't improved over this many checkpoints
    :param initial_patience (int): if given, allows adjusting patience for the first LR decay
    :param dev_metrics: Comma-separated list of evaluation metrics (bleu/wer/cer)
    :param schedule_metric: determine learning schedule based on this dev_metric (loss/bleu/wer/cer)
    :param restart_trainer: Restart trainer (useful for Adam) and revert weights to best dev checkpoint when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
    :param reload_command: Command to change the input data after each epoch.
                           --epoch EPOCH_NUM will be appended to the command.
                           To just reload the data after each epoch set the command to 'true'.
    :param name: will be prepended to log outputs if given
    :param inference: used for inference during dev checkpoints if dev_metrics are specified
    """
        assert yaml_context is not None
        self.yaml_context = yaml_context
        self.model_file = self.yaml_context.dynet_param_collection.model_file
        self.yaml_serializer = YamlSerializer()

        if lr_decay > 1.0 or lr_decay <= 0.0:
            raise RuntimeError(
                "illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
        self.lr_decay = lr_decay
        self.patience = patience
        self.initial_patience = initial_patience
        self.lr_decay_times = lr_decay_times
        self.restart_trainer = restart_trainer
        self.run_for_epochs = run_for_epochs

        self.early_stopping_reached = False
        # training state
        self.training_state = TrainingState()

        self.evaluators = [
            s.lower() for s in dev_metrics.split(",") if s.strip() != ""
        ]
        if schedule_metric.lower() not in self.evaluators:
            self.evaluators.append(schedule_metric.lower())
        if "loss" not in self.evaluators: self.evaluators.append("loss")
        if dev_metrics:
            self.inference = inference or SimpleInference()

        self.reload_command = reload_command
        if reload_command is not None:
            self._augmentation_handle = None
            self._augment_data_initial()

        self.model = model
        self.corpus_parser = corpus_parser
        self.loss_calculator = loss_calculator or LossCalculator(MLELoss())
        self.pretrained_model_file = pretrained_model_file
        if self.pretrained_model_file:
            self.yaml_context.dynet_param_collection.load_from_data_file(
                self.pretrained_model_file + '.data')

        self.batcher = batcher or SrcBatcher(32)
        if src_format == "contvec":
            self.batcher.pad_token = np.zeros(self.model.src_embedder.emb_dim)
        self.pack_batches()
        self.logger = BatchLossTracker(self, dev_every, name)

        self.schedule_metric = schedule_metric.lower()
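
In the older API of Examples #3 and #4, dev_metrics is a comma-separated string and schedule_metric picks which of the resulting evaluators drives the learning schedule. The snippet below simply traces how the evaluators list is assembled in the constructor shown above (the metric values are illustrative):

# How the evaluators list from Examples #3/#4 is built; values are illustrative.
dev_metrics = "bleu,wer"
schedule_metric = "bleu"
evaluators = [s.lower() for s in dev_metrics.split(",") if s.strip() != ""]
if schedule_metric.lower() not in evaluators:
  evaluators.append(schedule_metric.lower())
if "loss" not in evaluators:
  evaluators.append("loss")
print(evaluators)  # ['bleu', 'wer', 'loss']
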
Example #5
  def __init__(self, model, src_file=None, trg_file=None, dev_every=0,
               batcher=bare(SrcBatcher, batch_size=32), loss_calculator=None,
               run_for_epochs=None, lr_decay=1.0, lr_decay_times=3, patience=1,
               initial_patience=None, dev_tasks=None, restart_trainer=False,
               reload_command=None, name=None, sample_train_sents=None,
               max_num_train_sents=None, max_src_len=None, max_trg_len=None,
               exp_global=Ref(Path("exp_global"))):
    """
    Args:
      exp_global:
      model: a generator.GeneratorModel object
      src_file: The file for the source data.
      trg_file: The file for the target data.
      dev_every (int): dev checkpoints every n sentences (0 for only after epoch)
      batcher: Type of batcher
      loss_calculator:
      run_for_epochs (int): number of epochs (None for unlimited epochs)
      lr_decay (float):
      lr_decay_times (int):  Early stopping after decaying learning rate a certain number of times
      patience (int): apply LR decay after dev scores haven't improved over this many checkpoints
      initial_patience (int): if given, allows adjusting patience for the first LR decay
      dev_tasks: A list of tasks to run on the development set
      restart_trainer: Restart trainer (useful for Adam) and revert weights to best dev checkpoint when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
      reload_command: Command to change the input data after each epoch.
                           --epoch EPOCH_NUM will be appended to the command.
                           To just reload the data after each epoch set the command to 'true'.
      sample_train_sents: If given, load a random subset of training sentences before each epoch. Useful when training data does not fit in memory.
      max_num_train_sents:
      max_src_len:
      max_trg_len:
      name: will be prepended to log outputs if given
    """
    self.exp_global = exp_global
    self.model_file = self.exp_global.dynet_param_collection.model_file
    self.src_file = src_file
    self.trg_file = trg_file
    self.dev_tasks = dev_tasks

    if lr_decay > 1.0 or lr_decay <= 0.0:
      raise RuntimeError("illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
    self.lr_decay = lr_decay
    self.patience = patience
    self.initial_patience = initial_patience
    self.lr_decay_times = lr_decay_times
    self.restart_trainer = restart_trainer
    self.run_for_epochs = run_for_epochs

    self.early_stopping_reached = False
    # training state
    self.training_state = TrainingState()

    self.reload_command = reload_command

    self.model = model
    self.loss_calculator = loss_calculator or LossCalculator(MLELoss())

    self.sample_train_sents = sample_train_sents
    self.max_num_train_sents = max_num_train_sents
    self.max_src_len = max_src_len
    self.max_trg_len = max_trg_len

    self.batcher = batcher
    self.logger = BatchLossTracker(self, dev_every, name)
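
All five constructors validate 0.0 < lr_decay <= 1.0, and checkpoint() multiplies the learning rate by lr_decay at most lr_decay_times times before triggering early stopping. A small worked calculation of that schedule (the initial learning rate and decay values are made up for illustration):

# Worked example of the decay schedule enforced in checkpoint(); values are illustrative.
lr, lr_decay, lr_decay_times = 0.001, 0.5, 3
num_times_lr_decayed = 0
schedule = []
while True:
  num_times_lr_decayed += 1
  if num_times_lr_decayed > lr_decay_times:
    schedule.append('early stopping')
    break
  lr *= lr_decay
  schedule.append(lr)
print(schedule)  # [0.0005, 0.00025, 0.000125, 'early stopping']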