Example #1
    def __init__(self,
                 model=Ref("model"),
                 src_file=None,
                 trg_file=None,
                 dev_every=0,
                 dev_zero=False,
                 batcher=bare(xnmt.batcher.SrcBatcher, batch_size=32),
                 loss_calculator=bare(MLELoss),
                 trainer=None,
                 run_for_epochs=None,
                 lr_decay=1.0,
                 lr_decay_times=3,
                 patience=1,
                 initial_patience=None,
                 dev_tasks=None,
                 dev_combinator=None,
                 restart_trainer: bool = False,
                 reload_command=None,
                 name="{EXP}",
                 sample_train_sents=None,
                 max_num_train_sents=None,
                 max_src_len=None,
                 max_trg_len=None,
                 commandline_args=Ref("exp_global.commandline_args",
                                      default=None)):

        super().__init__(model=model,
                         src_file=src_file,
                         trg_file=trg_file,
                         dev_every=dev_every,
                         batcher=batcher,
                         loss_calculator=loss_calculator,
                         run_for_epochs=run_for_epochs,
                         lr_decay=lr_decay,
                         lr_decay_times=lr_decay_times,
                         patience=patience,
                         initial_patience=initial_patience,
                         dev_tasks=dev_tasks,
                         dev_combinator=dev_combinator,
                         restart_trainer=restart_trainer,
                         reload_command=reload_command,
                         name=name,
                         sample_train_sents=sample_train_sents,
                         max_num_train_sents=max_num_train_sents,
                         max_src_len=max_src_len,
                         max_trg_len=max_trg_len)
        self.dev_zero = dev_zero
        self.trainer = trainer or xnmt.optimizer.SimpleSGDTrainer(e0=0.1)
        self.dynet_profiling = getattr(commandline_args, "dynet_profiling",
                                       0) if commandline_args else 0
        self.train_loss_tracker = TrainLossTracker(self)
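
The last lines of this constructor read an optional attribute off an argparse-style Namespace. A self-contained trace of that getattr fallback (the Namespace contents below are made up for illustration, not taken from the source):

from argparse import Namespace

# With parsed args that carry the flag, its value is used.
commandline_args = Namespace(dynet_profiling=1)
print(getattr(commandline_args, "dynet_profiling", 0) if commandline_args else 0)  # 1

# With no args object at all, the expression short-circuits to 0.
commandline_args = None
print(getattr(commandline_args, "dynet_profiling", 0) if commandline_args else 0)  # 0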
Example #2
  def __init__(self, model: ConditionedModel = Ref("model"), src_file: Union[None, str, Sequence[str]] = None,
               trg_file: Optional[str] = None, dev_every: int = 0, dev_zero: bool = False,
               batcher: batcher.Batcher = bare(batcher.SrcBatcher, batch_size=32),
               loss_calculator: LossCalculator = bare(AutoRegressiveMLELoss),
               trainer: optimizer.XnmtOptimizer = bare(optimizer.SimpleSGDTrainer, e0=0.1),
               run_for_epochs: Optional[int] = None, lr_decay: float = 1.0, lr_decay_times: int = 3, patience: int = 1,
               initial_patience: Optional[int] = None, dev_tasks: Sequence[eval_task.EvalTask] = None,
               dev_combinator: Optional[str] = None, restart_trainer: bool = False,
               reload_command: Optional[str] = None, name: str = "{EXP}", sample_train_sents: Optional[int] = None,
               max_num_train_sents: Optional[int] = None, max_src_len: Optional[int] = None,
               max_trg_len: Optional[int] = None,
               loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
               update_every: int = 1,
               commandline_args: dict = Ref("exp_global.commandline_args", default={})) -> None:

    super().__init__(model=model,
                     src_file=src_file,
                     trg_file=trg_file,
                     dev_every=dev_every,
                     batcher=batcher,
                     loss_calculator=loss_calculator,
                     run_for_epochs=run_for_epochs,
                     lr_decay=lr_decay,
                     lr_decay_times=lr_decay_times,
                     patience=patience,
                     initial_patience=initial_patience,
                     dev_tasks=dev_tasks,
                     dev_combinator=dev_combinator,
                     restart_trainer=restart_trainer,
                     reload_command=reload_command,
                     name=name,
                     sample_train_sents=sample_train_sents,
                     max_num_train_sents=max_num_train_sents,
                     max_src_len=max_src_len,
                     max_trg_len=max_trg_len)
    self.dev_zero = dev_zero
    self.trainer = trainer or optimizer.SimpleSGDTrainer(e0=0.1)
    self.dynet_profiling = commandline_args.get("dynet_profiling", 0) if commandline_args else 0
    self.train_loss_tracker = TrainLossTracker(self)
    self.loss_comb_method = loss_comb_method
    self.update_every = update_every
    self.num_updates_skipped = 0
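
Relative to Example #1, this version adds type annotations, receives commandline_args as a dict (hence .get instead of getattr), and introduces loss_comb_method and update_every. The docstring in Example #9 describes loss_comb_method as the way per-batch-element losses are combined ("sum" or "avg"); a toy numeric illustration with made-up loss values (the real combination happens on DyNet loss expressions, not Python floats):

per_sentence_losses = [2.0, 4.0, 6.0]   # hypothetical per-sentence losses in one minibatch
combined = {
    "sum": sum(per_sentence_losses),                             # 12.0
    "avg": sum(per_sentence_losses) / len(per_sentence_losses),  # 4.0
}
print(combined)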
Example #3
 def __init__(self,
              tasks: Sequence[training_task.TrainingTask],
              trainer: optimizer.XnmtOptimizer = bare(optimizer.SimpleSGDTrainer, e0=0.1),
              dev_zero: bool = False,
              loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
              update_every: int = 1,
              commandline_args: dict = Ref("exp_global.commandline_args", default=None)) -> None:
   super().__init__(tasks=tasks, trainer=trainer, dev_zero=dev_zero, commandline_args=commandline_args,
                    update_every=update_every)
   self.train_loss_trackers = {task: TrainLossTracker(task) for task in tasks}
   self.loss_comb_method = loss_comb_method
Example #4
 def __init__(self,
              tasks,
              trainer=None,
              dev_zero=False,
              commandline_args=Ref("exp_global.commandline_args",
                                   default=None)):
     super().__init__(tasks=tasks,
                      trainer=trainer,
                      dev_zero=dev_zero,
                      commandline_args=commandline_args)
     self.train_loss_trackers = {
         task: TrainLossTracker(task)
         for task in tasks
     }
Example #5
 def __init__(self,
              tasks: Sequence[training_task.TrainingTask],
               trainer: optimizer.XnmtOptimizer = bare(optimizer.SimpleSGDTrainer, e0=0.1),
              dev_zero: bool = False,
              per_task_backward: bool = True,
              loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
              update_every: int = 1,
              n_task_steps: Optional[Sequence[int]] = None,
              commandline_args: dict = Ref("exp_global.commandline_args", default=None)) -> None:
   super().__init__(tasks=tasks, trainer=trainer, dev_zero=dev_zero, update_every=update_every,
                    commandline_args=commandline_args)
    self.train_loss_trackers = {task: TrainLossTracker(task) for task in tasks}
   self.per_task_backward = per_task_backward
   self.loss_comb_method = loss_comb_method
   self.n_task_steps = n_task_steps or [1] * len(tasks)
   if len(self.n_task_steps) != len(tasks):
     raise ValueError(f"number of tasks and steps per task do not match: {len(tasks)} != {len(self.n_task_steps)}")
Example #6
 def __init__(self,
              tasks,
              task_weights=None,
              trainer=None,
              dev_zero=False,
              commandline_args=Ref("exp_global.commandline_args",
                                   default=None)):
     super().__init__(tasks=tasks,
                      trainer=trainer,
                      dev_zero=dev_zero,
                      commandline_args=commandline_args)
     self.task_weights = task_weights or [1. / len(tasks)] * len(tasks)
     if len(self.task_weights) != len(self.tasks):
         raise ValueError(
             f"number of tasks must match number of task weights; "
             f"found: {len(self.task_weights)} != {len(self.tasks)}")
     self.train_loss_trackers = {
         task: TrainLossTracker(task)
         for task in tasks
     }
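
When task_weights is omitted, this constructor falls back to uniform weights over the tasks. A short trace of that default, with strings standing in for the task objects:

tasks = ["mt_task", "parse_task", "lm_task"]   # placeholders for TrainingTask objects
task_weights = None
task_weights = task_weights or [1. / len(tasks)] * len(tasks)
print(task_weights)   # [0.333..., 0.333..., 0.333...]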
Example #7
 def __init__(self,
              tasks: Sequence[training_task.TrainingTask],
              task_weights: Optional[Sequence[float]] = None,
              trainer: optimizer.XnmtOptimizer = bare(optimizer.SimpleSGDTrainer, e0=0.1),
              dev_zero: bool = False,
              loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
              update_every_within: int = 1,
              update_every_across: int = 1,
              commandline_args=Ref("exp_global.commandline_args", default=None)) -> None:
   super().__init__(tasks=tasks, trainer=trainer, dev_zero=dev_zero, update_every=update_every_across,
                    commandline_args=commandline_args)
    if update_every_within != 1 and update_every_across != 1:
     raise ValueError("update_every_within and update_every_across cannot be mixed.")
   self.update_every_within = update_every_within
   self.task_weights = task_weights or [1./len(tasks)] * len(tasks)
   if len(self.task_weights) != len(self.tasks):
     raise ValueError(f"number of tasks must match number of task weights; "
                      f"found: {len(self.task_weights)} != {len(self.tasks)}")
   self.train_loss_trackers = {task: TrainLossTracker(task) for task in tasks}
   self.loss_comb_method = loss_comb_method
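
This typed variant keeps task_weights and adds update_every_within / update_every_across, rejecting configurations that set both. How the weights are consumed is not shown in this excerpt; one plausible reading is that a task is drawn at random for each minibatch with probability proportional to its weight. The sketch below illustrates that reading only as an assumption, with random.choices standing in for whatever sampling the regimen actually performs:

import random

tasks = ["mt_task", "tag_task", "lm_task"]   # placeholders for TrainingTask objects
task_weights = [0.5, 0.3, 0.2]

# Draw the task to train on for each of the next five minibatches.
for step in range(5):
    task = random.choices(tasks, weights=task_weights, k=1)[0]
    print(step, task)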
Example #8
class SimpleTrainingRegimen(SimpleTrainingTask, TrainingRegimen, Serializable):
    """
  Args:
    model (TrainableModel): the model
    src_file (str): the source training file
    trg_file (str): the target training file
    dev_every (int): dev checkpoints every n sentences (0 for only after epoch)
    dev_zero (bool): if True, add a checkpoint before training loop is entered (useful with pretrained models).
    batcher (Batcher): Type of batcher
    loss_calculator (LossCalculator): The method for calculating the loss.
    trainer (XnmtOptimizer): Trainer object, default is SGD with learning rate 0.1
    run_for_epochs (int): number of epochs to train for
    lr_decay (float): factor by which the learning rate is multiplied when decaying (1.0 means no decay)
    lr_decay_times (int):  Early stopping after decaying learning rate a certain number of times
    patience (int): apply LR decay after dev scores haven't improved over this many checkpoints
    initial_patience (int): if given, allows adjusting patience for the first LR decay
    dev_tasks (List[EvalTask]): A list of tasks to use during the development stage.
    dev_combinator: A formula to combine together development scores into a single score to
                    choose whether to perform learning rate decay, etc.
                    e.g. 'x[0]-x[1]' would say that the first dev task score minus the
                    second dev task score is our measure of how good we're doing. If not
                    specified, only the score from the first dev task will be used.
    restart_trainer (bool): Restart trainer (useful for Adam) and revert weights to best dev checkpoint when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
    reload_command (str): Command to change the input data after each epoch.
                         --epoch EPOCH_NUM will be appended to the command.
                         To just reload the data after each epoch set the command to 'true'.
    name (str): will be prepended to log outputs if given
    sample_train_sents (int): if given, train on a random subset of this many training sentences
    max_num_train_sents (int): if given, train on no more than this many training sentences
    max_src_len (int): if given, skip training sentences whose source side is longer than this
    max_trg_len (int): if given, skip training sentences whose target side is longer than this
    commandline_args (Namespace): the parsed command-line arguments
  """
    yaml_tag = '!SimpleTrainingRegimen'

    @serializable_init
    def __init__(self,
                 model=Ref("model"),
                 src_file=None,
                 trg_file=None,
                 dev_every=0,
                 dev_zero=False,
                 batcher=bare(xnmt.batcher.SrcBatcher, batch_size=32),
                 loss_calculator=bare(MLELoss),
                 trainer=None,
                 run_for_epochs=None,
                 lr_decay=1.0,
                 lr_decay_times=3,
                 patience=1,
                 initial_patience=None,
                 dev_tasks=None,
                 dev_combinator=None,
                 restart_trainer: bool = False,
                 reload_command=None,
                 name="{EXP}",
                 sample_train_sents=None,
                 max_num_train_sents=None,
                 max_src_len=None,
                 max_trg_len=None,
                 commandline_args=Ref("exp_global.commandline_args",
                                      default=None)):

        super().__init__(model=model,
                         src_file=src_file,
                         trg_file=trg_file,
                         dev_every=dev_every,
                         batcher=batcher,
                         loss_calculator=loss_calculator,
                         run_for_epochs=run_for_epochs,
                         lr_decay=lr_decay,
                         lr_decay_times=lr_decay_times,
                         patience=patience,
                         initial_patience=initial_patience,
                         dev_tasks=dev_tasks,
                         dev_combinator=dev_combinator,
                         restart_trainer=restart_trainer,
                         reload_command=reload_command,
                         name=name,
                         sample_train_sents=sample_train_sents,
                         max_num_train_sents=max_num_train_sents,
                         max_src_len=max_src_len,
                         max_trg_len=max_trg_len)
        self.dev_zero = dev_zero
        self.trainer = trainer or xnmt.optimizer.SimpleSGDTrainer(e0=0.1)
        self.dynet_profiling = getattr(commandline_args, "dynet_profiling",
                                       0) if commandline_args else 0
        self.train_loss_tracker = TrainLossTracker(self)

    def run_training(self, save_fct, update_weights=True):
        """
    Main training loop (overrides TrainingRegimen.run_training())
    """
        if self.run_for_epochs > 0:
            for src, trg in self.next_minibatch():
                if self.dev_zero:
                    self.checkpoint_and_save(save_fct)
                    self.dev_zero = False
                dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                            check_validity=settings.CHECK_VALIDITY)
                with self.train_loss_tracker.time_tracker:
                    self.model.set_train(True)
                    loss_builder = self.training_step(src, trg)
                    loss = loss_builder.compute()
                    if update_weights:
                        self.update_weights(loss, self.trainer,
                                            self.dynet_profiling)
                self.train_loss_tracker.report(trg,
                                               loss_builder.get_loss_stats())
                if self.checkpoint_needed():
                    self.checkpoint_and_save(save_fct)
                if self.should_stop_training(): break

    def checkpoint_and_save(self, save_fct):
        should_save = self.checkpoint()
        if should_save:
            save_fct()
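
run_training only requires save_fct to be a callable taking no arguments; checkpoint_and_save invokes it whenever checkpoint() reports that the current model should be kept, and the update_weights flag lets the loop compute losses without applying parameter updates. A hypothetical stand-in, just to show the shape of the callback (the real save function is provided elsewhere in xnmt and is not part of this excerpt):

def make_save_fct(path):
    # Placeholder: the real callback would serialize the model/experiment to `path`.
    def save_fct():
        print(f"checkpoint accepted, would save to {path}")
    return save_fct

# regimen.run_training(save_fct=make_save_fct("experiment.mod"))   # hypothetical usage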
Example #9
class SimpleTrainingRegimen(training_task.SimpleTrainingTask, TrainingRegimen, Serializable):
  """
  Args:
    model: the model
    src_file: the source training file
    trg_file: the target training file
    dev_every: dev checkpoints every n sentences (0 for only after epoch)
    dev_zero: if True, add a checkpoint before training loop is entered (useful with pretrained models).
    batcher: Type of batcher
    loss_calculator: The method for calculating the loss.
    trainer: Trainer object, default is SGD with learning rate 0.1
    run_for_epochs: number of epochs to train for
    lr_decay: factor by which the learning rate is multiplied when decaying (1.0 means no decay)
    lr_decay_times:  Early stopping after decaying learning rate a certain number of times
    patience: apply LR decay after dev scores haven't improved over this many checkpoints
    initial_patience: if given, allows adjusting patience for the first LR decay
    dev_tasks: A list of tasks to use during the development stage.
    dev_combinator: A formula to combine together development scores into a single score to
                    choose whether to perform learning rate decay, etc.
                    e.g. 'x[0]-x[1]' would say that the first dev task score minus the
                    second dev task score is our measure of how good we're doing. If not
                    specified, only the score from the first dev task will be used.
    restart_trainer: Restart trainer (useful for Adam) and revert weights to best dev checkpoint when applying
                            LR decay (https://arxiv.org/pdf/1706.09733.pdf)
    reload_command: Command to change the input data after each epoch.
                         --epoch EPOCH_NUM will be appended to the command.
                         To just reload the data after each epoch set the command to ``true``.
    name: will be prepended to log outputs if given
    sample_train_sents: if given, train on a random subset of this many training sentences
    max_num_train_sents: if given, train on no more than this many training sentences
    max_src_len: if given, skip training sentences whose source side is longer than this
    max_trg_len: if given, skip training sentences whose target side is longer than this
    loss_comb_method: method for combining loss across batch elements (``sum`` or ``avg``).
    update_every: simulate large-batch training by accumulating gradients over several steps before updating parameters
    commandline_args: dictionary of command-line arguments
  """
  yaml_tag = '!SimpleTrainingRegimen'

  @serializable_init
  def __init__(self, model: ConditionedModel = Ref("model"), src_file: Union[None, str, Sequence[str]] = None,
               trg_file: Optional[str] = None, dev_every: int = 0, dev_zero: bool = False,
               batcher: batcher.Batcher = bare(batcher.SrcBatcher, batch_size=32),
               loss_calculator: LossCalculator = bare(AutoRegressiveMLELoss),
               trainer: optimizer.XnmtOptimizer = bare(optimizer.SimpleSGDTrainer, e0=0.1),
               run_for_epochs: Optional[int] = None, lr_decay: float = 1.0, lr_decay_times: int = 3, patience: int = 1,
               initial_patience: Optional[int] = None, dev_tasks: Sequence[eval_task.EvalTask] = None,
               dev_combinator: Optional[str] = None, restart_trainer: bool = False,
               reload_command: Optional[str] = None, name: str = "{EXP}", sample_train_sents: Optional[int] = None,
               max_num_train_sents: Optional[int] = None, max_src_len: Optional[int] = None,
               max_trg_len: Optional[int] = None,
               loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
               update_every: int = 1,
               commandline_args: dict = Ref("exp_global.commandline_args", default={})) -> None:

    super().__init__(model=model,
                     src_file=src_file,
                     trg_file=trg_file,
                     dev_every=dev_every,
                     batcher=batcher,
                     loss_calculator=loss_calculator,
                     run_for_epochs=run_for_epochs,
                     lr_decay=lr_decay,
                     lr_decay_times=lr_decay_times,
                     patience=patience,
                     initial_patience=initial_patience,
                     dev_tasks=dev_tasks,
                     dev_combinator=dev_combinator,
                     restart_trainer=restart_trainer,
                     reload_command=reload_command,
                     name=name,
                     sample_train_sents=sample_train_sents,
                     max_num_train_sents=max_num_train_sents,
                     max_src_len=max_src_len,
                     max_trg_len=max_trg_len)
    self.dev_zero = dev_zero
    self.trainer = trainer or optimizer.SimpleSGDTrainer(e0=0.1)
    self.dynet_profiling = commandline_args.get("dynet_profiling", 0) if commandline_args else 0
    self.train_loss_tracker = TrainLossTracker(self)
    self.loss_comb_method = loss_comb_method
    self.update_every = update_every
    self.num_updates_skipped = 0

  def run_training(self, save_fct):
    """
    Main training loop (overrides TrainingRegimen.run_training())
    """
    if self.run_for_epochs > 0:
      for src, trg in self.next_minibatch():
        if self.dev_zero:
          self.checkpoint_and_save(save_fct)
          self.dev_zero = False
        with util.ReportOnException({"src": src, "trg": trg, "graph": dy.print_text_graphviz}):
          dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
          with self.train_loss_tracker.time_tracker:
            self.model.set_train(True)
            loss_builder = self.training_step(src, trg)
            loss = loss_builder.compute()
            self.backward(loss, self.dynet_profiling)
            self.update(self.trainer)
          self.train_loss_tracker.report(trg, loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method))
        if self.checkpoint_needed():
          self.checkpoint_and_save(save_fct)
        if self.should_stop_training(): break

  def checkpoint_and_save(self, save_fct):
    should_save = self.checkpoint()
    if should_save:
      save_fct()

  def update(self, trainer: optimizer.XnmtOptimizer) -> None:
    self.num_updates_skipped += 1
    if self.num_updates_skipped == self.update_every:
      trainer.update()
      self.num_updates_skipped = 0
    else:
      assert 0 < self.num_updates_skipped < self.update_every