Example No. 1
  def reset(self, output_dir, init_checkpoint=None):
    """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
      init_checkpoint: Initial checkpoint (default $output_dir/model.pkl.gz)
    """
    self.close()
    self._output_dir = output_dir
    if output_dir is not None:
      tf.io.gfile.makedirs(output_dir)
    else:
      assert not self._should_save_checkpoints
      assert not self._should_write_summaries

    # Create summary writers and history.
    if self._should_write_summaries:
      self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'),
                                              enable=self._is_chief)
      self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'eval'),
                                             enable=self._is_chief)

    # Reset the train and eval streams.
    self._train_stream = _repeat_stream(self._inputs.train_stream,
                                        self._n_devices)
    # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
    #   set by adding a padding and stopping the stream when too large.
    self._eval_stream = _repeat_stream(
        self._inputs.eval_stream, self._n_devices)
    self._train_eval_stream = _repeat_stream(
        self._inputs.train_eval_stream, self._n_devices)

    # Restore the training state.
    if output_dir is not None:
      state = load_trainer_state(output_dir, self._model_with_loss,
                                 init_checkpoint)
    else:
      state = TrainerState(step=None, opt_state=None,
                           history=trax_history.History(), model_state=None)
    self._step = state.step or 0
    history = state.history
    self._history = history
    if state.opt_state:
      opt_state = state.opt_state
      model_state = state.model_state
    else:
      opt_state, model_state = self._new_opt_state_and_model_state()
      model_state = self._for_n_devices(model_state)
    self._opt_state = OptState(*self._for_n_devices(opt_state))
    self._model_state = model_state
    if not state.opt_state and self._should_save_checkpoints:
      self.save_state(keep=False)
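
A hedged usage sketch for the `reset` method above: `restart_in_dir` is a hypothetical helper, and `trainer` stands for any object exposing this `reset` (such as the trainer class the method belongs to).

import os

def restart_in_dir(trainer, output_dir, checkpoint=None):
  """Hypothetical helper: re-point `trainer` at `output_dir`.

  If `output_dir` (or the explicit `checkpoint`) already holds a saved
  `model.pkl.gz`, parameters are restored from it; otherwise they are
  randomly re-initialized. Either way the model is not re-jitted.
  """
  trainer.reset(os.path.expanduser(output_dir), init_checkpoint=checkpoint)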
Example No. 2
def load_trainer_state(output_dir, model, weights_file=None):
  """Returns a TrainerState instance loaded from the given `output_dir`."""
  if weights_file is None:
    weights_file = os.path.join(output_dir, 'model.pkl.gz')
    if not tf.io.gfile.exists(weights_file):
      return TrainerState(step=None, opt_state=None,
                          history=trax_history.History(), model_state=None)
  elif not tf.io.gfile.exists(weights_file):
    raise ValueError('File not found: %s' % weights_file)

  trainer_state_dict = unpickle_from_file(weights_file, gzip=True)
  trainer_state = trainer_state_from_dict(trainer_state_dict, model)
  log('Model loaded from %s at step %d' % (weights_file, trainer_state.step))
  logging.debug('From loaded model : history = %s', trainer_state.history)
  return trainer_state
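
A hedged sketch of how this loader composes with the `reset` method in Example No. 1; `resume_step` is a hypothetical helper that mirrors the `state.step or 0` logic used there.

def resume_step(output_dir, model_with_loss, weights_file=None):
  """Hypothetical helper: the step to resume from, or 0 for a fresh run.

  `load_trainer_state` returns a TrainerState with `step=None` when
  `output_dir` contains no `model.pkl.gz` and no explicit `weights_file`
  is given, so falling back to 0 matches the behavior of `reset` above.
  """
  state = load_trainer_state(output_dir, model_with_loss, weights_file)
  return state.step or 0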
Example No. 3
  def __init__(
      self,
      model,
      tasks,
      eval_model=None,
      eval_tasks=None,
      output_dir=None,
      checkpoint_at=None,
      permanent_checkpoint_at=None,
      eval_at=None,
      which_task=None,
      n_devices=None,
      random_seed=None,
      loss_chunk_size=0,
      use_memory_efficient_trainer=False,
      callbacks=None,
  ):
    """Configures a training ``Loop``, including a random initialization.

    Args:
      model: Trax layer, representing the core model to be trained. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      tasks: List of :py:class:`TrainTask` instances, which define the training
          data, loss function, and optimizer to be used in respective tasks in
          this training loop. It can also be a single :py:class:`TrainTask`
          instance which is treated in the same way as a singleton list.
      eval_model: Optional Trax layer, representing model used for evaluation,
          e.g., with dropout turned off. If ``None``, the training model (model)
          will be used.
      eval_tasks: List of :py:class:`EvalTask` instances which define how to
          evaluate the model: which validation data to use and which metrics to
          report. Evaluation on each of the tasks will run and be reported
          separately, which allows scoring a model on different subtasks. This
          argument can also be ``None``, in which case no evals will be run, or
          a single :py:class:`EvalTask`, which will be treated in the same way
          as a singleton list.
      output_dir: Path telling where to save outputs (evals and checkpoints).
          Can be ``None`` if both ``eval_tasks`` and ``checkpoint_at`` are
          ``None``.
      checkpoint_at: Function (integer --> boolean) telling, for step n, whether
          that step should have its checkpoint saved. If ``None``, the default
          is periodic checkpointing at ``task.n_steps_per_checkpoint``.
      permanent_checkpoint_at: Function (integer --> boolean) telling,
          for step n, whether that step should have its checkpoint saved
          permanently. If ``None``, the default is periodic checkpointing at
          ``task.n_steps_per_permanent_checkpoint``.
      eval_at: Function (integer --> boolean) that says, for training step n,
          whether that step should run evals. If ``None``, run when
          checkpointing.
      which_task: Function (integer --> integer) indicating which task should be
          used at which training step. Can be set to ``None`` in single-task
          training.
      n_devices: integer or ``None``, the number of devices for this
          computation.
      random_seed: the random seed to use; time/os dependent if ``None``
          (default).
      loss_chunk_size: int, if > 0 use chunks of this size to make loss
        computation more memory-efficient.
      use_memory_efficient_trainer: whether to use a special memory-efficient
        trainer; if set to 2, the memory efficiency is very aggressive.
      callbacks: List of subclasses of StepCallback to call on training
        steps.
    """
    self._is_chief, self._n_hosts, self._n_devices, self._rng = (
        init_host_and_devices(n_devices, random_seed))
    if use_memory_efficient_trainer:
      self._rng = tl.on_cpu(self._rng)

    # Handle single task case without lists too.
    if not isinstance(tasks, (list, tuple)):
      tasks = [tasks]

    if not tasks:
      raise ValueError('Must provide at least one training task.')
    if eval_tasks is None:
      eval_tasks = []
      eval_at = _never
    else:
      if not isinstance(eval_tasks, (list, tuple)):
        eval_tasks = [eval_tasks]

    self._tasks = tasks
    self._model = model
    self._eval_model = eval_model or model

    self._use_memory_efficient_trainer = use_memory_efficient_trainer
    self._loss_chunk_size = loss_chunk_size
    # TODO(lukaszkaiser): can we have different eval models and save memory?
    if use_memory_efficient_trainer:
      assert len(tasks) == 1, 'only single task supported for now'
      self._eval_model = model

    default_at = _at_step_1_and_every_nth_step(tasks[0].n_steps_per_checkpoint)
    permanent_default_at = _at_step_1_and_every_nth_step(
        tasks[0].n_steps_per_permanent_checkpoint)
    if output_dir is not None:
      self._output_dir = os.path.expanduser(output_dir)
      tf.io.gfile.makedirs(self._output_dir)
      inputs.load_data_counters(self._output_dir)
    else:
      self._output_dir = None

    # Prepare training components.
    self._step = 0
    self._history = trax_history.History()
    self._checkpoint_at = checkpoint_at or default_at
    self._permanent_checkpoint_at = (
        permanent_checkpoint_at or permanent_default_at)
    if which_task is None:
      # If which_task is not passed, cycle through the tasks one by one, so
      # if len(tasks) == 1, which_task is a constant function equal to 0.
      which_task = lambda n: n % len(tasks)
    self._which_task = which_task

    # Initialize using the given random seed.
    # NOTE: If random_seed is None then self._rng will be different on
    # different hosts, leading to different weights on the different hosts.
    self._batch_signature = shapes.signature(tasks[0].sample_batch)
    self._model.rng = self.new_rng()
    # In the memory-efficient case, we initialize in init_trainer.
    if not use_memory_efficient_trainer:
      if _is_uninitialized(self._model):
        self._model.init(self._batch_signature)
      self._eval_model.rng = self.new_rng()
      if _is_uninitialized(self._eval_model):
        self._eval_model.init(self._batch_signature)

    # To handle the above case (i.e. random_seed = None), we psum the weights
    # and state and average them.
    # NOTE: This adds time (how much?) so we prefer not to do it if it is
    # unnecessary, i.e. random_seed was set.
    # NOTE: Averaging the weights across devices can screw up the initial weight
    # statistics.
    # TODO(pkozakowski): Broadcast from one of the devices instead?
    # TODO(lukaszkaiser): make it work for the memory-efficient trainer too.
    if (random_seed is None and self._n_hosts > 1 and
        not use_memory_efficient_trainer):
      logging.info('Syncing weights/state across %d hosts.', self._n_hosts)
      self._sync_weights_and_state_across_hosts()

    # Create the optimizer for the training loss function.
    self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
    self.load_checkpoint()

    # Prepare eval components.
    self._eval_at = eval_at or default_at
    self._eval_tasks = eval_tasks
    loss_names = [task.loss_name for task in self._tasks]
    metric_names = [
        name  # pylint: disable=g-complex-comprehension
        for eval_task in self._eval_tasks
        for name in eval_task.metric_names
    ]
    self._rjust_len = max(map(len, loss_names + metric_names))
    self._evaluator_per_task = tuple(
        self._init_evaluator(eval_task) for eval_task in self._eval_tasks)

    if self._output_dir is None:
      _log('Will not write evaluation metrics, because output_dir is None.')

    def task_output_dir(task_index, task_list):
      if self._output_dir is not None:
        if len(task_list) < 2:
          output_dir = self._output_dir
        else:
          output_dir = os.path.join(self._output_dir, str(task_index))
        tf.io.gfile.makedirs(output_dir)
        return output_dir
      else:
        return None
    self._output_dir_per_eval_task = [
        task_output_dir(i, eval_tasks) for i in range(len(eval_tasks))]
    self._output_dir_per_train_task = [
        task_output_dir(i, tasks) for i in range(len(tasks))]

    callbacks = callbacks or []
    self._callbacks = [
        callback_class(self) for callback_class in callbacks
    ]
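
To show how this constructor is typically driven end to end, here is a minimal, self-contained sketch. The toy data stream, layer sizes, and output path are invented for illustration, and the loss/metric layer names (`WeightedCategoryCrossEntropy`, `WeightedCategoryAccuracy`) follow recent Trax releases, so they may differ in older versions.

import numpy as np

from trax import layers as tl
from trax import optimizers
from trax.supervised import training


def toy_batches(batch_size=8):
  """Hypothetical synthetic stream of (inputs, targets, weights) batches."""
  while True:
    x = np.random.uniform(size=(batch_size, 4)).astype(np.float32)
    y = (x.sum(axis=-1) > 2.0).astype(np.int32)
    yield x, y, np.ones_like(y, dtype=np.float32)


model = tl.Serial(tl.Dense(16), tl.Relu(), tl.Dense(2), tl.LogSoftmax())

train_task = training.TrainTask(
    labeled_data=toy_batches(),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=optimizers.Adam(0.01),
    n_steps_per_checkpoint=100,
)
eval_task = training.EvalTask(
    labeled_data=toy_batches(),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
)

loop = training.Loop(
    model,
    train_task,                  # a single TrainTask is accepted like a singleton list
    eval_tasks=[eval_task],
    output_dir='/tmp/toy_loop',  # checkpoints and eval summaries are written here
)
loop.run(n_steps=200)

Passing `train_task` without wrapping it in a list relies on the single-task handling in the constructor above, and because `eval_at` is left as ``None``, evals run whenever a checkpoint is saved.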
Example No. 4
  def test_metrics_for_mode(self):
    history = trax_history.History()
    history.append('train', 'metric1', 1, 0.1)
    history.append('train', 'metric2', 2, 0.2)
    self.assertEqual(history.metrics_for_mode('train'),
                     ['metric1', 'metric2'])
Example No. 5
  def test_modes(self):
    history = trax_history.History()
    history.append('train', 'metric1', 1, 0.1)
    history.append('test', 'metric2', 2, 0.2)
    self.assertEqual(history.modes, ['test', 'train'])
Example No. 6
  def test_serializer_and_deserializer(self):
    history = trax_history.History()
    history.append('train', 'metric1', 1, 0.1)
    json_object = history.to_dict()
    history2 = trax_history.History.from_dict(json_object)
    self.assertEqual(history2.get('train', 'metric1'), [(1, 0.1)])
Example No. 7
  def test_unknown_metric(self):
    history = trax_history.History()
    history.append('train', 'metric1', 1, 0.1)
    self.assertEqual(history.get('train', 'unknown_metric'), [])
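
For completeness, a small stand-alone sketch built only from the `History` calls exercised in these tests; the import path assumes `trax_history` aliases `trax.supervised.history`, as it does elsewhere in the Trax codebase.

from trax.supervised import history as trax_history

history = trax_history.History()
history.append('train', 'loss', 1, 0.9)
history.append('train', 'loss', 2, 0.7)
history.append('eval', 'accuracy', 2, 0.55)

print(history.modes)                      # e.g. ['eval', 'train']
print(history.metrics_for_mode('train'))  # e.g. ['loss']
print(history.get('train', 'loss'))       # e.g. [(1, 0.9), (2, 0.7)]

# Round-trip through a plain dict, as in test_serializer_and_deserializer.
restored = trax_history.History.from_dict(history.to_dict())
assert restored.get('train', 'loss') == history.get('train', 'loss')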