Example #1
    def test_run_reversible_same_as_default_extended(self):
        """Runs the reversible trainer, check results are the same as default."""
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = 2 * inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        # We want to test rng propagation too, so we add some dropout layers.
        first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup())
        rev_layers1 = [
            tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)),
            tl.ReversibleSwap(),
            tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)),
            tl.ReversibleSwap()
        ]
        mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup())
        rev_layers2 = [
            tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)),
            tl.ReversibleSwap()
        ]
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3),
                               tl.LogSoftmax(), tl.CrossEntropyLoss())
        model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] +
                          rev_layers2 + [loss_layer])
        rng_init = fastmath.random.get_prng(12)
        model.init(labeled_batch, rng=rng_init)
        optimizer_fn = optimizers.Adam  # to test slots

        # Make 3 steps with the original trainer.
        optimizer = optimizer_fn()
        optimizer.tree_init(model.weights)
        trainer = optimizers.Trainer(model, optimizer)
        rng_step1 = fastmath.random.get_prng(7)
        rng_step2 = fastmath.random.get_prng(8)
        rng_step3 = fastmath.random.get_prng(9)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
        first_layer_weights1 = first_layer.weights
        rev_layer12_weights1 = rev_layers1[2].weights
        mid_layer_weights1 = mid_layer.weights
        rev_layer20_weights1 = rev_layers2[0].weights
        loss_layer_weights1 = loss_layer.weights

        # Now make 3 steps with reversible trainer.
        model.init(labeled_batch, rng=rng_init)
        # TODO(lukaszkaiser): this test seems to fail with memoize_jit, why?
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer.sublayers, rev_layers1),
             (mid_layer.sublayers, rev_layers2)],
            loss_layer,
            optimizer_fn,
            memoize_jit=False)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

        # Check that weights end up the same.
        self._assert_all_equal(loss_layer_weights1, loss_layer.weights)
        self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights)
        self._assert_all_equal(mid_layer_weights1, mid_layer.weights)
        self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights)
        self._assert_all_equal(first_layer_weights1, first_layer.weights)
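
For reference, here is a minimal sketch of the `_assert_all_equal` helper the test relies on. It is not part of this snippet, so the implementation below is an assumption: it walks nested tuples/lists of weights and compares the leaf arrays with `np.testing.assert_allclose`.

    def _assert_all_equal(self, t1, t2, tol=1e-5):
        """Sketch (not from this snippet): asserts nested weight structures match."""
        if isinstance(t1, (list, tuple)):
            # Serial/reversible layers store weights as nested tuples; recurse.
            self.assertEqual(len(t1), len(t2))
            for x1, x2 in zip(t1, t2):
                self._assert_all_equal(x1, x2, tol)
        else:
            # Leaf arrays: compare element-wise within a small tolerance.
            np.testing.assert_allclose(np.asarray(t1), np.asarray(t2), atol=tol)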
Example #2
    def __init__(self,
                 model,
                 tasks,
                 eval_model=None,
                 eval_tasks=None,
                 output_dir=None,
                 checkpoint_at=None,
                 eval_at=None,
                 n_devices=None,
                 random_seed=None):
        """Configures a training `Loop`, including a random initialization.

        Args:
          model: Trax layer, representing the core model to be trained. Loss
              functions and eval functions (a.k.a. metrics) are considered to be
              outside the core model, taking core model output and data labels
              as their two inputs.
          tasks: List of `TrainTask` instances, which define the training data,
              loss function, and optimizer to be used in this training loop.
          eval_model: Optional Trax layer, representing the model used for
              evaluation, e.g., with dropout turned off. If None, the training
              model (`model`) will be used.
          eval_tasks: List of `EvalTask` instances or None. If None, don't run
              any evals.
          output_dir: Path telling where to save outputs (evals and
              checkpoints). Can be None if both `eval_tasks` and `checkpoint_at`
              are None.
          checkpoint_at: Function (integer --> boolean) telling, for step n,
              whether that step should have its checkpoint saved. If None, the
              default is periodic checkpointing at `task.n_steps_per_checkpoint`.
          eval_at: Function (integer --> boolean) that says, for training step
              n, whether that step should run evals. If None, run evals when
              checkpointing.
          n_devices: Integer or None, the number of devices for this
              computation.
          random_seed: The random seed to use; time/os-dependent if None
              (default).
        """
        self._is_chief, self._n_hosts, self._n_devices, self._rng = (
            init_host_and_devices(n_devices, random_seed))

        # Also accept a single task that is not wrapped in a list.
        if not isinstance(tasks, (list, tuple)):
            tasks = [tasks]

        assert len(tasks) == 1, 'Multitask training not supported yet.'
        task = tasks[0]
        if eval_tasks is None:
            eval_task = None
        else:
            assert len(eval_tasks) == 1, 'Multitask training not supported yet.'
            eval_task = eval_tasks[0]

        self._task = task
        self._model = model
        self._eval_model = eval_model or model
        default_at = (_at_step_1_and_every_nth_step(
            self._task.n_steps_per_checkpoint))
        if output_dir is not None:
            self._output_dir = os.path.expanduser(output_dir)
            tf.io.gfile.makedirs(self._output_dir)
        else:
            self._output_dir = None

        # Prepare training components.
        self._step = 0
        self._checkpoint_at = checkpoint_at or default_at
        self._batch_signature = shapes.signature(self._task.sample_batch)
        self._model_in_training = tl.Serial(self._model, self._task.loss_layer)

        # Initialize using the given random seed.
        # NOTE: If `random_seed` is `None` then `self._rng` will be different on
        # different hosts, leading to different weights on the different hosts.
        self._model_in_training.rng = self.new_rng()
        self._model_in_training.init(self._batch_signature)
        self._eval_model.rng = self.new_rng()
        self._eval_model.init(self._batch_signature)

        # To handle the above case (i.e. random_seed is None), we psum the
        # weights and state across hosts and average them.
        # NOTE: This adds synchronization time, so we skip it when it is
        # unnecessary, i.e. when random_seed was set.
        if random_seed is None and self._n_hosts > 1:
            logging.info('Syncing weights/state across %d hosts.',
                         self._n_hosts)
            self._sync_weights_and_state_across_hosts()

        # Initialize optimizer slots, then restore from a checkpoint if one exists.
        self._task.optimizer.tree_init(self._model_in_training.weights)
        self.load_checkpoint()

        # Create the optimizer for the training loss function.
        self._trainer = optimizers.Trainer(self._model_in_training,
                                           self._task.optimizer)

        # Prepare eval components.
        if eval_task is None:
            self._eval_at = _never
        else:
            self._eval_task = eval_task
            self._eval_at = eval_at or default_at
            metric_name_lengths = [
                len(name) for name in self._eval_task.metric_names
            ]
            self._rjust_len = max([len(self._task.loss_layer.name)] +
                                  metric_name_lengths)
            model_with_metrics = (_model_with_metrics(self._eval_model,
                                                      self._eval_task))
            # Keep self._eval_{weights/state} replicated.
            self._eval_weights = self._for_n_devices(
                model_with_metrics.weights[1])  # just the eval part
            self._eval_state = self._for_n_devices(
                model_with_metrics.state[1])  # just the eval part
            self._metrics_fn = _accelerate_model_with_metrics(
                model_with_metrics, self.n_devices)
            if self._output_dir is None:
                _log('Will not write evaluation metrics, '
                     'because output_dir is None.')
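
Below is a minimal usage sketch of this constructor. The entry points (`training.Loop`, `training.TrainTask`, `training.EvalTask`, `trax.layers`, `trax.optimizers`) are the public Trax APIs this snippet belongs to, but the toy `data_stream` generator, the layer stack, and all hyperparameters are illustrative assumptions only.

import tempfile

import numpy as np
from trax import layers as tl
from trax import optimizers
from trax.supervised import training

# Hypothetical toy data: batches of (inputs, targets, weights).
def data_stream():
    while True:
        x = np.random.randint(0, 10, size=(8, 4))   # token ids
        y = np.random.randint(0, 10, size=(8,))     # class labels
        yield (x, y, np.ones_like(y, dtype=np.float32))

# A small classifier; LogSoftmax pairs with CrossEntropyLoss, as in Example #1.
model = tl.Serial(
    tl.Embedding(10, 32),  # (vocab_size, d_feature), same order as Example #1
    tl.Mean(axis=1),
    tl.Dense(10),
    tl.LogSoftmax(),
)

train_task = training.TrainTask(
    labeled_data=data_stream(),
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=optimizers.Adam(0.01),
    n_steps_per_checkpoint=10,
)

eval_task = training.EvalTask(
    labeled_data=data_stream(),
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
)

# `tasks` may be a single task or a list; `eval_tasks` must be a list or None.
# `output_dir` receives evals and checkpoints.
loop = training.Loop(model, [train_task],
                     eval_tasks=[eval_task],
                     output_dir=tempfile.mkdtemp())
loop.run(n_steps=20)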