Example #1
  def __init__(self, loss_layer, optimizer, n_devices=None):
    self._loss_layer = loss_layer
    self._optimizer = optimizer
    self._n_devices = n_devices or fastmath.device_count()

    # optimizer slots and opt_params may need to be replicated
    self._slots, self._opt_params = tl.for_n_devices(
        (self._optimizer.slots, self._optimizer.opt_params), self._n_devices)

    # accelerated version of loss layer to replicate weights and state
    self._accelerated_loss_layer = tl.Accelerate(
        loss_layer, n_devices=n_devices)

    # Signature:
    # (batch, weights, state, rng) -> ((loss, state), gradients)
    self._forward_and_backward_fn = (
        fastmath.value_and_grad(
            loss_layer.pure_fn,
            argnums=1,  # arg1 of pure_fn: weights
            has_aux=True))  # return (loss, state), gradients

    # Signature:
    # (weights, slots), step, opt_params, batch, state, rng ->
    # (weights, slots), state, stats
    self._accelerated_update_fn = (
        _accelerate_update_fn(
            self._forward_and_backward_fn,
            self._optimizer,
            n_devices=self._n_devices,
            accelerate=True,
        )
    )
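
A minimal usage sketch for this constructor. It assumes the enclosing class is Trax's `Trainer` (from `trax.optimizers.trainer`) and that updates run through a `one_step(batch, rng, step=...)` method; both names are assumptions here, and exact signatures vary across Trax versions:

import numpy as np
from trax import fastmath, layers as tl, optimizers, shapes

# Model plus loss, as a single layer that maps (inputs, targets, weights)
# to a scalar loss.
loss_layer = tl.Serial(tl.Dense(10), tl.LogSoftmax(), tl.CrossEntropyLoss())
batch = (np.zeros((8, 32), dtype=np.float32),   # inputs
         np.zeros((8,), dtype=np.int32),        # targets
         np.ones((8,), dtype=np.float32))       # per-example weights
loss_layer.init(shapes.signature(batch))

# The constructor reads optimizer.slots, so tree_init must run first.
optimizer = optimizers.Adam(0.001)
optimizer.tree_init(loss_layer.weights)

trainer = Trainer(loss_layer, optimizer)        # the class defined above
loss = trainer.one_step(batch, fastmath.random.get_prng(0), step=0)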
Example #2
    def __init__(self,
                 model_with_loss,
                 optimizer,
                 n_devices=None,
                 adasum=False):
        self._model_with_loss = model_with_loss
        self._optimizer = optimizer
        self._n_devices = n_devices or fastmath.local_device_count()
        self._adasum = adasum

        # optimizer slots and opt_params may need to be replicated
        self._slots, self._opt_params = tl.on_cpu(
            tl.for_n_devices(
                (self._optimizer.slots, self._optimizer.opt_params),
                self._n_devices))

        # accelerated version of model+loss to replicate weights and state
        self._accelerated_model_with_loss = tl.Accelerate(model_with_loss,
                                                          n_devices=n_devices)

        # Signature:
        # (batch, weights, state, rng) -> ((loss, state), gradients)
        self._forward_and_backward_fn = (
            fastmath.value_and_grad(
                model_with_loss.pure_fn,
                argnums=1,  # arg1 of pure_fn: weights
                has_aux=True))  # return (loss, state), gradients

        # Signature:
        # (weights, slots), step, opt_params, batch, state, rng ->
        # (weights, slots), state, stats
        self._accelerated_update_fn = (_accelerate_update_fn(
            self._forward_and_backward_fn,
            self._optimizer,
            n_devices=self._n_devices,
            accelerate=True,
            adasum=self._adasum))
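
Compared with Example #1, this variant counts only local devices (`fastmath.local_device_count()`), keeps the replicated slots and `opt_params` in host memory via `tl.on_cpu`, and threads an `adasum` flag through to the update function for Adasum-style cross-device gradient aggregation. The replication helper both examples rely on can be sketched in isolation (the leading-device-axis shape is an assumption about `tl.for_n_devices`):

import numpy as np
from trax import layers as tl

slots = {'m': np.zeros((3,)), 'v': np.zeros((3,))}
replicated = tl.for_n_devices(slots, 2)
# Each leaf gains a leading device axis: replicated['m'].shape == (2, 3),
# i.e., one copy of the optimizer slot per device.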
Example #3
    def __init__(self,
                 model,
                 task,
                 eval_model=None,
                 eval_task=None,
                 output_dir=None,
                 checkpoint_at=None,
                 eval_at=None):
        """Configures a training `Loop`, including a random initialization.

        Args:
          model: Trax layer, representing the core model to be trained. Loss
              functions and eval functions (a.k.a. metrics) are considered to be
              outside the core model, taking core model output and data labels as
              their two inputs.
          task: TrainTask instance, which defines the training data, loss function,
              and optimizer to be used in this training loop.
          eval_model: Optional Trax layer, representing model used for evaluation,
              e.g., with dropout turned off. If None, the training model (model)
              will be used.
          eval_task: EvalTask instance or None. If None, don't do any evals.
          output_dir: Path telling where to save outputs (evals and checkpoints).
              Can be None if both `eval_task` and `checkpoint_at` are None.
          checkpoint_at: Function (integer --> boolean) telling, for step n, whether
              that step should have its checkpoint saved. If None, the default is
              periodic checkpointing at `task.n_steps_per_checkpoint`.
          eval_at: Function (integer --> boolean) that says, for training step n,
              whether that step should run evals. If None, run when checkpointing.
        """
        self._task = task
        self._model = model
        self._eval_model = eval_model or model
        default_at = (_at_step_1_and_every_nth_step(
            self._task.n_steps_per_checkpoint))
        if output_dir is not None:
            self._output_dir = os.path.expanduser(output_dir)
            tf.io.gfile.makedirs(self._output_dir)
        else:
            self._output_dir = None

        # Prepare training components.
        self._step = 0
        self._checkpoint_at = checkpoint_at or default_at
        self._model_in_training = tl.Serial(self._model, self._task.loss_layer)
        self._batch_signature = shapes.signature(self._task.sample_batch)
        self._eval_model.init(self._batch_signature)
        self._model_in_training.init(self._batch_signature)
        self._task.optimizer.tree_init(self._model_in_training.weights)
        self._forward_and_backward_fn = (
            fastmath.jit(
                fastmath.value_and_grad(
                    self._model_in_training.pure_fn,
                    argnums=1,  # arg1 of pure_fn: weights
                    has_aux=True)))  # return (loss, state), gradients

        # Prepare eval components.
        if eval_task is None:
            self._eval_at = _never
        else:
            self._eval_task = eval_task
            self._eval_at = eval_at or default_at
            metric_name_lengths = [
                len(name) for name in self._eval_task.metric_names
            ]
            self._rjust_len = max([len(self._task.loss_layer.name)] +
                                  metric_name_lengths)
            model_with_metrics = (_model_with_metrics(self._eval_model,
                                                      self._eval_task))
            self._eval_weights = model_with_metrics.weights[1]  # just the eval part
            self._eval_state = model_with_metrics.state[1]  # just the eval part
            self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)
            if self._output_dir is None:
                _log('Will not write evaluation metrics, '
                     'because output_dir is None.')
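
A minimal end-to-end sketch for this constructor; the data stream, layer sizes, and hyperparameters are placeholders, while `TrainTask`, `EvalTask`, and `Loop.run` are the standard `trax.supervised.training` API:

import numpy as np
from trax import layers as tl, optimizers
from trax.supervised import training

def data_stream():  # placeholder stream of (inputs, targets, weights)
    while True:
        yield (np.zeros((8, 32), dtype=np.float32),
               np.zeros((8,), dtype=np.int32),
               np.ones((8,), dtype=np.float32))

task = training.TrainTask(
    labeled_data=data_stream(),
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=optimizers.Adam(0.001),
    n_steps_per_checkpoint=100)
eval_task = training.EvalTask(
    labeled_data=data_stream(),
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()])

model = tl.Serial(tl.Dense(10), tl.LogSoftmax())
loop = training.Loop(model, task,
                     eval_task=eval_task, output_dir='/tmp/train_dir')
loop.run(n_steps=200)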
Example #4
  def __init__(self, model, tasks, eval_model=None, eval_tasks=None,
               output_dir=None, checkpoint_at=None, eval_at=None,
               n_devices=None, random_seed=None):
    """Configures a training `Loop`, including a random initialization.

    Args:
      model: Trax layer, representing the core model to be trained. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      tasks: List of TrainTask instances, which define the training data, loss
          function, and optimizer to be used in respective tasks in this
          training loop.
      eval_model: Optional Trax layer, representing model used for evaluation,
        e.g., with dropout turned off. If None, the training model (model)
        will be used.
      eval_tasks: List of EvalTask instances or None. If None, don't do any
        evals.
      output_dir: Path telling where to save outputs (evals and checkpoints).
          Can be None if both `eval_tasks` and `checkpoint_at` are None.
      checkpoint_at: Function (integer --> boolean) telling, for step n, whether
          that step should have its checkpoint saved. If None, the default is
          periodic checkpointing at `task.n_steps_per_checkpoint`.
      eval_at: Function (integer --> boolean) that says, for training step n,
          whether that step should run evals. If None, run when checkpointing.
      n_devices: integer or None, the number of devices for this computation.
      random_seed: the random seed to use; time/os dependent if None (default).
    """
    self._is_chief, self._n_hosts, self._n_devices, self._rng = (
        init_host_and_devices(n_devices, random_seed))

    # Handle single task case without lists too.
    if not isinstance(tasks, (list, tuple)):
      tasks = [tasks]

    assert len(tasks) == 1, 'Multitask training not supported yet.'
    task = tasks[0]
    if eval_tasks is None:
      eval_task = None
    else:
      assert len(eval_tasks) == 1, 'Multitask training not supported yet.'
      eval_task = eval_tasks[0]

    self._task = task
    self._model = model
    self._eval_model = eval_model or model
    default_at = (
        _at_step_1_and_every_nth_step(self._task.n_steps_per_checkpoint))
    if output_dir is not None:
      self._output_dir = os.path.expanduser(output_dir)
      tf.io.gfile.makedirs(self._output_dir)
    else:
      self._output_dir = None

    # Prepare training components.
    self._step = 0
    self._checkpoint_at = checkpoint_at or default_at
    self._batch_signature = shapes.signature(self._task.sample_batch)
    self._model_in_training = tl.Serial(self._model, self._task.loss_layer)

    # Initialize using the given random seed.
    # NOTE: If `random_seed` is `None` then `self._rng` will be different on
    # different hosts, leading to different weights on the different hosts.
    self._model_in_training.rng = self.new_rng()
    self._model_in_training.init(self._batch_signature)
    self._eval_model.rng = self.new_rng()
    self._eval_model.init(self._batch_signature)

    # To handle the above case (i.e. random_seed = None), we psum the weights
    # and state and average them.
    # NOTE: This adds time (how much?) so we prefer not to do it if it is
    # unnecessary, i.e. random_seed was set.
    if random_seed is None and self._n_hosts > 1:
      logging.info('Syncing weights/state across %d hosts.', self._n_hosts)
      self._sync_weights_and_state_across_hosts()

    self._task.optimizer.tree_init(self._model_in_training.weights)

    # Signature:
    # (batch, weights, state, rng) -> ((loss, state), gradients)
    self._forward_and_backward_fn = (
        fastmath.value_and_grad(
            self._model_in_training.pure_fn,
            argnums=1,  # arg1 of pure_fn: weights
            has_aux=True))  # return (loss, state), gradients

    # Signature:
    # (weights, slots), step, opt_params, batch, state, rng ->
    # (weights, slots), state, stats
    self._accelerated_update_fn = (
        _accelerate_update_fn(
            self._forward_and_backward_fn,
            self._task.optimizer,
            n_devices=self.n_devices,
            accelerate=True,
        )
    )

    # Restore from checkpoint if there is one.
    self.load_checkpoint()

    # Prepare eval components.
    if eval_task is None:
      self._eval_at = _never
    else:
      self._eval_task = eval_task
      self._eval_at = eval_at or default_at
      metric_name_lengths = [len(name) for name in self._eval_task.metric_names]
      self._rjust_len = max(
          [len(self._task.loss_layer.name)] + metric_name_lengths)
      model_with_metrics = (
          _model_with_metrics(self._eval_model, self._eval_task))
      # Keep self._eval_{weights/state} replicated.
      self._eval_weights = self._for_n_devices(
          model_with_metrics.weights[1])  # just the eval part
      self._eval_state = self._for_n_devices(
          model_with_metrics.state[1])  # just the eval part
      self._metrics_fn = _accelerate_model_with_metrics(
          model_with_metrics, self.n_devices)
      if self._output_dir is None:
        _log('Will not write evaluation metrics, because output_dir is None.')
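
Relative to Example #3, this version takes lists of tasks (currently length 1, per the assertion), replicates eval weights and state across devices, and syncs weights across hosts when `random_seed` is None. A sketch of the list-based call with a custom schedule predicate, reusing `model`, `task`, and `eval_task` from the Example #3 sketch:

# Reuses `training`, `model`, `task`, and `eval_task` from the sketch above.
def at_step_1_then_every_200(step):  # (integer -> boolean), as documented
    return step == 1 or step % 200 == 0

loop = training.Loop(
    model, [task],                    # tasks are passed as a list here
    eval_tasks=[eval_task],
    eval_at=at_step_1_then_every_200,
    checkpoint_at=at_step_1_then_every_200,
    output_dir='/tmp/train_dir',
    n_devices=None,                   # None: use all available devices
    random_seed=42)                   # a fixed seed keeps hosts in sync
loop.run(n_steps=400)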
Example #5
    def __init__(self,
                 model,
                 task,
                 eval_model=None,
                 eval_task=None,
                 output_dir=None,
                 checkpoint_at=None,
                 eval_at=None):
        """Configures a training `Loop`, including a random initialization.

        Args:
          model: Trax layer, representing the core model to be trained. Loss
              functions and eval functions (a.k.a. metrics) are considered to be
              outside the core model, taking core model output and data labels as
              their two inputs.
          task: TrainTask instance, which defines the training data, loss function,
              and optimizer to be used in this training loop.
          eval_model: Optional Trax layer, representing model used for evaluation,
              e.g., with dropout turned off. If None, the training model (model)
              will be used.
          eval_task: EvalTask instance or None. If None, don't do any evals.
          output_dir: Path telling where to save outputs (evals and checkpoints).
              Can be None if both `eval_task` and `checkpoint_at` are None.
          checkpoint_at: Function (integer --> boolean) telling, for step n, whether
              that step should have its checkpoint saved. If None, the default is
              periodic checkpointing at `task.n_steps_per_checkpoint`.
          eval_at: Function (integer --> boolean) that says, for training step n,
              whether that step should run evals. If None, run when checkpointing.
        """
        self._task = task
        self._model = model
        self._model_in_training = tl.Serial(model, task.loss_layer)
        self._eval_model = model if eval_model is None else eval_model
        self._eval_task = eval_task
        # Guard against eval_task=None, which the docstring explicitly allows.
        self._rjust_len = max(
            [0] + ([len(name) for name in eval_task.metric_names]
                   if eval_task is not None else []))

        if output_dir is not None:
            self._output_dir = os.path.expanduser(output_dir)
            tf.io.gfile.makedirs(self._output_dir)
        else:
            self._output_dir = None
        default_fn = _at_step_1_and_periodically_at(
            task.n_steps_per_checkpoint)
        self._checkpoint_at = checkpoint_at or default_fn
        self._eval_at = eval_at or default_fn
        if eval_task is None:
            self._eval_at = _never
        self._step = 0

        batch_signature = shapes.signature(task.sample_batch)
        self._batch_signature = batch_signature
        # Initialize the model and the optimizer; discard the return values
        # (model weights/state, optimizer slots/params), since they're available
        # from the model and optimizer objects.
        _, _ = self._model_in_training.init(batch_signature)
        _, _ = task.optimizer.tree_init(self._model_in_training.weights)

        self._gradients_and_state_fn = (
            fastmath.jit(
                fastmath.value_and_grad(
                    self._model_in_training.pure_fn,
                    argnums=1,  # arg1 of pure_fn: weights
                    has_aux=True)))  # return (loss, state), gradients

        if eval_task is not None:
            model_with_metrics = _model_with_metrics(self._eval_model,
                                                     eval_task)
            self._eval_weights = model_with_metrics.weights[1]  # just the eval part
            self._eval_state = model_with_metrics.state[1]  # just the eval part
            self._metrics_fn = fastmath.jit(model_with_metrics.pure_fn)
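
All five examples build the same core object: `fastmath.value_and_grad` applied to `pure_fn` with `argnums=1` (differentiate with respect to weights) and `has_aux=True` (the updated state rides along with the loss). A toy stand-in shows the resulting call shape; `toy_pure_fn` is hypothetical and only mirrors the `(batch, weights, state, rng)` signature:

from trax import fastmath
from trax.fastmath import numpy as jnp

def toy_pure_fn(batch, weights, state, rng):
    inputs, targets = batch
    loss = jnp.mean((inputs * weights - targets) ** 2)
    return loss, state  # (loss, new_state): state is the aux output

grad_fn = fastmath.jit(
    fastmath.value_and_grad(toy_pure_fn, argnums=1, has_aux=True))
(loss, new_state), grads = grad_fn(
    (jnp.ones(4), jnp.zeros(4)),      # batch
    jnp.ones(4),                      # weights (argnums=1)
    (),                               # state
    fastmath.random.get_prng(0))      # rng
# `grads` matches the structure of `weights`: here, a length-4 array.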