Example #1
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=tl.WeightedCategoryCrossEntropy(),
          inputs=trax_inputs.batcher,
          optimizer=trax_opt.Adafactor,
          lr_schedule_fn=lr.multifactor,
          trainer_class=Trainer,
          steps=1000,
          checkpoints_at=None,
          permanent_checkpoints_at=None,
          eval_steps=10,
          eval_frequency=100,
          permanent_checkpoint_frequency=None,
          random_seed=None,
          save_graphs=True,
          metrics=None,
          checkpoint_highest=None,
          checkpoint_lowest=None,
          use_loop=True,
          loss_chunk_size=0,
          use_memory_efficient_trainer=False):
    """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: weights, trax.inputs.Inputs, model, state,
      rng -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule_fn: A learning rate schedule function that, when called, returns
      a function from step to learning rate (a float).
    trainer_class: The trainer class to use.
    steps: int, total number of training steps.
    checkpoints_at: list of integers. Save a checkpoint for each training step
      in the list.
    permanent_checkpoints_at: list of integers. Save a permanent checkpoint for
      each training step in the list.
    eval_steps: int, number of steps per evaluation. If None or 0, eval is
      disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval is disabled.
    permanent_checkpoint_frequency: int, how often to save permanent checkpoints
      (every permanent_checkpoint_frequency steps).
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    metrics: optionally override the default metrics dictionary.
    checkpoint_highest: save a checkpoint when this metric reaches its highest
      value.
    checkpoint_lowest: save a checkpoint when this metric reaches its lowest
      value.
    use_loop: whether to use training.Loop instead of Trainer.
    loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
    use_memory_efficient_trainer: whether to use memory-efficient trainer.

  Returns:
    trax.TrainerState or training.Loop if use_loop is True
  """
    if (permanent_checkpoint_frequency is not None
            and permanent_checkpoints_at is not None):
        raise ValueError('Only one of ["permanent_checkpoint_frequency", '
                         '"permanent_checkpoints_at"] should be set.')
    if use_loop:
        n_devices = num_devices() or fastmath.device_count()

        # Prepare the training task.
        # Inputs is either an Inputs instance or a function that returns it.
        # If we pass a function, e.g., through gin, call it.
        if callable(inputs):
            inputs = inputs()
        opt = optimizer if use_memory_efficient_trainer else optimizer()
        train_task = training.TrainTask(
            inputs.train_stream(n_devices),
            loss_layer=loss_fn,
            optimizer=opt,
            lr_schedule=lr_schedule_fn(),
            n_steps_per_checkpoint=eval_frequency,
            n_steps_per_permanent_checkpoint=permanent_checkpoint_frequency)

        # Prepare the evaluation.
        metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        names, metrics = zip(*metrics_dict.items())
        eval_task = training.EvalTask(inputs.eval_stream(n_devices),
                                      metrics,
                                      metric_names=names,
                                      n_eval_batches=eval_steps)

        # Prepare the training loop.
        checkpoint_at = None
        if checkpoints_at is not None:
            checkpoint_at = lambda step: step in checkpoints_at
        permanent_checkpoint_at = None
        if permanent_checkpoints_at is not None:
            permanent_checkpoint_at = (
                lambda step: step in permanent_checkpoints_at)
        loop = training.Loop(
            model(mode='train'), [train_task],
            eval_model=model(mode='eval'),
            eval_tasks=[eval_task],
            output_dir=output_dir,
            checkpoint_at=checkpoint_at,
            permanent_checkpoint_at=permanent_checkpoint_at,
            n_devices=n_devices,
            loss_chunk_size=loss_chunk_size,
            use_memory_efficient_trainer=use_memory_efficient_trainer,
            random_seed=random_seed)

        steps_to_go = steps - loop.step
        if steps_to_go <= 0:
            log('Stop training, already reached the total training steps %d' %
                steps)
            return loop

        # Train and return the loop.
        loop.run(steps_to_go)
        return loop

    n_devices = num_devices()
    trainer = trainer_class(model,
                            loss_fn,
                            optimizer,
                            lr_schedule_fn(),
                            inputs,
                            output_dir,
                            random_seed=random_seed,
                            n_devices=n_devices,
                            checkpoints_at=checkpoints_at,
                            metrics=metrics,
                            checkpoint_lowest=checkpoint_lowest,
                            checkpoint_highest=checkpoint_highest)

    epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
    if eval_frequency and eval_steps > 0:
        epoch_steps = itertools.chain(
            [
                1,  # first epoch only 1 step
                eval_frequency - 1
            ],
            itertools.repeat(eval_frequency))
    trainer.log_step('Starting training using %d devices' % trainer.n_devices)
    trainer.print_n_weights()

    try:
        for epoch_steps in epochs(steps, trainer.step, epoch_steps):
            trainer.train_epoch(epoch_steps, eval_steps)

            # Bookkeeping we do at the first step
            if trainer.step == 1:
                # Save computation graph (single-device only for now)
                if (save_graphs and fastmath.is_backend(fastmath.Backend.JAX)):
                    trainer.save_computation_graphs()

                # Save Gin config
                trainer.save_gin()

        trainer.log_step('Training done')
    except Exception as e:
        raise e
    finally:
        trainer.close()
    return trainer.state
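
A minimal usage sketch for the `train` entry point above, assuming it is importable from `trax.supervised.trainer_lib`. The toy model and the random input stream are illustrative assumptions, not taken from the source; in real runs `model` and `inputs` are usually bound through gin.

import numpy as np
from trax import data
from trax import layers as tl
from trax.supervised import trainer_lib

def tiny_model(mode='train'):
    # Any callable returning a Trax layer per mode works here.
    del mode
    return tl.Serial(tl.Embedding(vocab_size=256, d_feature=32),
                     tl.Mean(axis=1), tl.Dense(2), tl.LogSoftmax())

def stream(n_devices):
    del n_devices
    while True:  # (inputs, targets, weights) batches of random toy data.
        x = np.random.randint(0, 256, size=(8, 16))
        y = np.random.randint(0, 2, size=(8,))
        yield (x, y, np.ones_like(y, dtype=np.float32))

loop = trainer_lib.train(
    output_dir='/tmp/tiny_run',
    model=tiny_model,
    inputs=data.inputs.Inputs(train_stream=stream, eval_stream=stream),
    steps=200, eval_frequency=100, eval_steps=2)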
Example #2
    def __init__(
            self,
            task,
            value_body=None,
            value_optimizer=None,
            value_lr_schedule=lr.multifactor,
            value_batch_size=64,
            value_train_steps_per_epoch=500,
            value_evals_per_epoch=1,
            value_eval_steps=1,
            exploration_rate=functools.partial(
                lr.multifactor,
                factors='constant * decay_every',
                constant=1.,  # pylint: disable=redefined-outer-name
                decay_factor=0.99,
                steps_per_decay=1,
                minimum=0.1),
            n_eval_episodes=0,
            only_eval=False,
            n_replay_epochs=1,
            max_slice_length=1,
            sync_freq=1000,
            scale_value_targets=True,
            output_dir=None,
            **kwargs):
        """Configures the value trainer.

    Args:
      task: RLTask instance, which defines the environment to train on.
      value_body: Trax layer, representing the body of the value model. Loss
          functions and eval functions (a.k.a. metrics) are considered to be
          outside the core model, taking core model output and data labels as
          their two inputs.
      value_optimizer: the optimizer to use to train the value model.
      value_lr_schedule: learning rate schedule to use to train the value model.
      value_batch_size: batch size used to train the value model.
      value_train_steps_per_epoch: how long to train the value model in each RL
          epoch.
      value_evals_per_epoch: number of value trainer evaluations per RL epoch
          - only affects metric reporting.
      value_eval_steps: number of value trainer steps per evaluation - only
          affects metric reporting.
      exploration_rate: exploration rate schedule - used in the policy method.
      n_eval_episodes: number of episodes to play with policy at
        temperature 0 in each epoch -- used for evaluation only
      only_eval: If set to True, then trajectories are collected only for
        evaluation purposes, but they are not recorded.
      n_replay_epochs: Number of last epochs to take into the replay buffer;
          only makes sense for off-policy algorithms.
      max_slice_length: the maximum length of trajectory slices to use; it is
          the second dimension of the value network output:
          (batch, max_slice_length, number of actions).
          Higher max_slice_length implies that the network has to predict more
          values into the future.
      sync_freq: how often (in steps) to synchronize the target network with
        the trained network. This is necessary for training the network on
        bootstrapped targets, e.g. using n-step returns.
      scale_value_targets: If `True`, scale value function targets by
          `1 / (1 - gamma)`. We are trying to fix the problem with very large
          returns in some games in a way which does not introduce additional
          hyperparameters.
      output_dir: Path telling where to save outputs (evals and checkpoints).
      **kwargs: arguments for the superclass RLTrainer.
    """
        super(ValueAgent, self).__init__(task,
                                         n_eval_episodes=n_eval_episodes,
                                         output_dir=output_dir,
                                         **kwargs)
        self._value_batch_size = value_batch_size
        self._value_train_steps_per_epoch = value_train_steps_per_epoch
        self._value_evals_per_epoch = value_evals_per_epoch
        self._value_eval_steps = value_eval_steps
        self._only_eval = only_eval
        self._max_slice_length = max_slice_length
        self._policy_dist = distributions.create_distribution(
            task.action_space)
        self._n_replay_epochs = n_replay_epochs

        self._exploration_rate = exploration_rate()
        self._sync_at = (lambda step: step % sync_freq == 0)

        if scale_value_targets:
            self._value_network_scale = 1 / (1 - self._task.gamma)
        else:
            self._value_network_scale = 1

        value_model = functools.partial(models.Quality,
                                        body=value_body,
                                        n_actions=self.task.action_space.n)

        self._value_eval_model = value_model(mode='eval')
        self._value_eval_model.init(self._value_model_signature)
        self._value_eval_jit = tl.jit_forward(self._value_eval_model.pure_fn,
                                              fastmath.device_count(),
                                              do_mean=False)

        # Inputs to the value model are produced by self.value_batches_stream.
        self._inputs = data.inputs.Inputs(
            train_stream=lambda _: self.value_batches_stream())

        # This is the value Trainer that will be used to train the value model.
        # * inputs to the trainer come from self.value_batches_stream
        # * outputs, targets and weights are passed to self.value_loss
        self._value_trainer = supervised.Trainer(
            model=value_model,
            optimizer=value_optimizer,
            lr_schedule=value_lr_schedule(),
            loss_fn=self.value_loss,
            inputs=self._inputs,
            output_dir=output_dir,
            metrics={
                'value_loss': self.value_loss,
                'value_mean': self.value_mean,
                'returns_mean': self.returns_mean
            })
        value_batch = next(self.value_batches_stream())
        self._eval_model = tl.Accelerate(value_model(mode='collect'),
                                         n_devices=1)
        self._eval_model.init(shapes.signature(value_batch))
        if self._task._initial_trajectories == 0:
            self._task.remove_epoch(0)
            self._collect_trajectories()
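
The default `exploration_rate` above builds an epsilon-style schedule out of `lr.multifactor`. A plain-Python sketch of the decay it describes, assuming the 'constant * decay_every' factors multiply `constant` by `decay_factor` every `steps_per_decay` steps and clip the result at `minimum`:

def exploration_rate(step, constant=1.0, decay_factor=0.99,
                     steps_per_decay=1, minimum=0.1):
    # constant * decay_factor ** (step // steps_per_decay), floored at minimum.
    return max(minimum, constant * decay_factor ** (step // steps_per_decay))

print(exploration_rate(0))    # 1.0
print(exploration_rate(100))  # ~0.366
print(exploration_rate(500))  # 0.1, clipped at the minimum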
Example #3
 def _get_conditionally_synced_rng(self):
     if self._sync and fastmath.device_count() > 1:
         return fastmath.psum(self.rng, 'batch')
     else:
         return self.rng
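
The check above synchronizes the per-device RNGs with a collective sum over the 'batch' axis, so every device ends up holding the same key. A hedged raw-JAX sketch of the same idea (illustrative only, not the Trax internals):

import jax

n = jax.local_device_count()
per_device_rngs = jax.random.split(jax.random.PRNGKey(0), n)

# Summing the per-device keys over the mapped axis gives each device an
# identical, synchronized key.
synced = jax.pmap(lambda k: jax.lax.psum(k, axis_name='batch'),
                  axis_name='batch')(per_device_rngs)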
Example #4
    def __init__(self,
                 task,
                 value_model=None,
                 value_optimizer=None,
                 value_lr_schedule=lr.multifactor,
                 value_batch_size=64,
                 value_train_steps_per_epoch=500,
                 value_evals_per_epoch=1,
                 value_eval_steps=1,
                 n_shared_layers=0,
                 added_policy_slice_length=0,
                 n_replay_epochs=1,
                 scale_value_targets=False,
                 q_value=False,
                 q_value_aggregate='logsumexp',
                 q_value_temperature=1.0,
                 q_value_n_samples=1,
                 q_value_normalization=False,
                 **kwargs):  # Arguments of PolicyAgent come here.
        """Configures the actor-critic trainer.

    Args:
      task: `RLTask` instance to use.
      value_model: Model to use for the value function.
      value_optimizer: Optimizer to train the value model.
      value_lr_schedule: lr schedule for value model training.
      value_batch_size: Batch size for value model training.
      value_train_steps_per_epoch: Number of steps used to train the value
          model in each epoch.
      value_evals_per_epoch: Number of value trainer evaluations per RL epoch.
          Every evaluation, we also synchronize the weights of the target
          network.
      value_eval_steps: Number of value trainer steps per evaluation; only
          affects metric reporting.
      n_shared_layers: Number of layers to share between value and policy
          models.
      added_policy_slice_length: How much longer slices of trajectories should
          be for policy training than for value training; this is useful for
          TD calculations and only affects the length of elements produced for
          policy batches; value batches have maximum length set by
          `max_slice_length` in `**kwargs`.
      n_replay_epochs: Number of last epochs to take into the replay buffer;
          only makes sense for off-policy algorithms.
      scale_value_targets: If `True`, scale value function targets by
          `1 / (1 - gamma)`.
      q_value: If `True`, use Q-values as baselines.
      q_value_aggregate: How to aggregate Q-values. Options: 'mean', 'max',
        'softmax', 'logsumexp'.
      q_value_temperature: Temperature parameter for the 'softmax' and
        'logsumexp' aggregation methods.
      q_value_n_samples: Number of samples to average over when calculating
          baselines based on Q-values.
      q_value_normalization: How to normalize Q-values before aggregation.
          Allowed values: 'std', 'abs', `None`. If `None`, don't normalize.
      **kwargs: Arguments for `PolicyAgent` superclass.
    """
        self._n_shared_layers = n_shared_layers
        self._value_batch_size = value_batch_size
        self._value_train_steps_per_epoch = value_train_steps_per_epoch
        self._value_evals_per_epoch = value_evals_per_epoch
        self._value_eval_steps = value_eval_steps

        # The two attributes below will be initialized in super().__init__
        # anyway, but they are needed to construct value batches, which are
        # needed before PolicyAgent init since policy input creation calls the
        # value model -- hence this code.
        self._task = task
        self._max_slice_length = kwargs.get('max_slice_length', 1)
        self._added_policy_slice_length = added_policy_slice_length
        self._n_replay_epochs = n_replay_epochs
        task.set_n_replay_epochs(n_replay_epochs)

        if scale_value_targets:
            self._value_network_scale = 1 / (1 - self._task.gamma)
        else:
            self._value_network_scale = 1

        self._q_value = q_value
        self._q_value_aggregate = q_value_aggregate
        self._q_value_temperature = q_value_temperature
        self._q_value_n_samples = q_value_n_samples
        self._q_value_normalization = q_value_normalization

        is_discrete = isinstance(self._task.action_space, gym.spaces.Discrete)
        self._is_discrete = is_discrete
        self._vocab_size = None
        self._sample_all_discrete_actions = False
        if q_value and is_discrete:
            self._vocab_size = self.task.action_space.n
            # TODO(lukaszkaiser): the code below is specific to AWR, move it.
            # If n_samples = n_actions, we'll take them all in actor and reweight.
            if self._q_value_n_samples == self._vocab_size:
                # TODO(lukaszkaiser): set this explicitly once it's in AWR Trainer.
                self._sample_all_discrete_actions = True

        if q_value:
            value_model = functools.partial(value_model,
                                            inject_actions=True,
                                            is_discrete=is_discrete,
                                            vocab_size=self._vocab_size)
        self._value_eval_model = value_model(mode='eval')
        self._value_eval_model.init(self._value_model_signature)
        self._value_eval_jit = tl.jit_forward(self._value_eval_model.pure_fn,
                                              fastmath.device_count(),
                                              do_mean=False)

        # Initialize policy training.
        super().__init__(task, **kwargs)

        # Initialize training of the value function.
        value_output_dir = kwargs.get('output_dir', None)
        if value_output_dir is not None:
            value_output_dir = os.path.join(value_output_dir, 'value')
            # If needed, create value_output_dir and missing parent directories.
            if not tf.io.gfile.isdir(value_output_dir):
                tf.io.gfile.makedirs(value_output_dir)
        self._value_inputs = data.inputs.Inputs(
            train_stream=lambda _: self.value_batches_stream())
        self._value_trainer = supervised.Trainer(
            model=value_model,
            optimizer=value_optimizer,
            lr_schedule=value_lr_schedule(),
            loss_fn=tl.L2Loss(),
            inputs=self._value_inputs,
            output_dir=value_output_dir,
            metrics={
                'value_loss': tl.L2Loss(),
                'value_mean': self.value_mean
            })
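
For the `q_value_aggregate` options documented above ('mean', 'max', 'softmax', 'logsumexp'), here is a hedged sketch of what temperature-controlled aggregation over sampled Q-values could look like; it illustrates the standard formulas, not the library's exact implementation:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def aggregate_q(q, method='logsumexp', temperature=1.0):
    # q: Q-value samples for one state, shape (n_samples,).
    if method == 'mean':
        return jnp.mean(q)
    if method == 'max':
        return jnp.max(q)
    if method == 'softmax':
        # Softmax-weighted average of the samples.
        return jnp.sum(jax.nn.softmax(q / temperature) * q)
    # 'logsumexp': a temperature-scaled smooth maximum.
    return temperature * logsumexp(q / temperature)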
Example #5
File: trainer.py Project: yliu45/trax
    def __init__(self,
                 blocks,
                 loss_layer,
                 optimizer_fn,
                 n_devices=None,
                 memoize_jit=True):
        """Creates a ReversibleSerialTrainer and the needed optimizers.

    This trainer performs updates equivalent to using the default Trainer on::

      tl.Serial(blocks + [loss_layer]).

    It is more memory-efficient though since weights are stored on CPU and only
    sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
    of standard (arbitrary) layers and a list of reversible layers which help
    save memory thanks to being reversible.

    Args:
      blocks: A list of pairs of lists of standard and reversible layers.
      loss_layer: The final layer of the model; it can have trainable weights
        but should end with a loss: it is required to produce a scalar output.
      optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      n_devices: An optional integer, number of accelerator devices to use;
        by default, all available accelerators will be used.
      memoize_jit: Whether to memoize JITed functions; this significantly speeds
        up XLA compilation of larger models, but it uses `repr(layer)` as keys
        to memoize, so it could fail if two layers with different functionality
        had the same string representation. We have not encountered such a case
        yet, so this is turned on by default, but consider turning it off or
        reviewing your model if you use custom layers and encounter a problem.
    """
        self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
        self._loss_layer = loss_layer
        self._optimizer_fn = optimizer_fn
        self._n_devices = n_devices or fastmath.device_count()
        self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
        self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
        self._jit_memory = {} if memoize_jit else None
        self._jit_per_device_rngs = fastmath.jit(self._per_device_rngs,
                                                 backend='cpu')

        # Create accelerated versions of layers as pmapped/jitted pure_fn.
        self._accelerated_layer_fns = fastmath.nested_map(
            lambda layer: self._pjit(layer.pure_fn, f'fwd {repr(layer)}'),
            self._blocks)

        # Create per-layer optimizers and replicate opt_params.
        def _make_optimizer(layer):
            opt = optimizer_fn()
            opt.tree_init(layer.weights)
            return opt

        self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
        self._replicated_opt_params = fastmath.nested_map(
            lambda opt: self._replicate_cpu(opt.opt_params), self._optimizers)

        self._loss_opt = _make_optimizer(loss_layer)
        self._replicated_loss_opt_params = self._replicate_cpu(
            self._loss_opt.opt_params)

        # Forward + backward + optimizer-update functions for all layers.
        # We call them in short FBO for "Forward + Backward + Optimizer update".
        # Reversible layers define a reverse_and_fbo function that also reverses.

        self._fbos = []
        for i, (std_layer, rev_layers) in enumerate(self._blocks):
            (std_opt, rev_opts) = self._optimizers[i]
            std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt,
                                              self._n_devices)
            rev_and_fbos = []
            for layer, opt in zip(rev_layers, rev_opts):
                rev_and_fbo = _reverse_and_fbo_with_layer_and_opt(
                    layer, opt, self._n_devices)
                # The donated args are (outputs, weights, grads) and we can donate
                # them because weights and grads are immediately replaced and in
                # case of reversible layers, the outputs are never used again.
                rev_and_fbos.append(
                    self._pjit(rev_and_fbo,
                               f'rev+bwd {repr(layer)}',
                               donate_argnums=(0, 1, 2)))
            # In standard layers, the inputs cannot be donated as they may be used
            # as outputs for the reversible block below, but weights and grads can.
            jit_std_fbo = self._pjit(std_fbo,
                                     f'bwd {repr(std_layer)}',
                                     donate_argnums=(1, 2))
            self._fbos.append((jit_std_fbo, rev_and_fbos))

        loss_fbo = _fbo_with_layer_and_opt(self._loss_layer, self._loss_opt,
                                           self._n_devices, 'loss')
        self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2))
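
A hypothetical construction sketch for the `blocks` structure this trainer consumes: each entry pairs standard layers (ending in `Dup` to create the two reversible streams) with reversible layers, and the loss layer starts by concatenating the streams back. Layer choices, sizes, and the sample batch below are assumptions for illustration only.

import numpy as np
from trax import layers as tl
from trax import optimizers, shapes

std = tl.Serial(tl.Embedding(vocab_size=16, d_feature=8), tl.Dup())
rev = [tl.ReversibleHalfResidual(tl.Dense(8)), tl.ReversibleSwap()]
loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(4), tl.LogSoftmax(),
                       tl.WeightedCategoryCrossEntropy())

# Weights must exist before the per-layer optimizers are created, so the
# composed model is initialized first on a sample (inputs, targets, weights).
batch = (np.zeros((2, 3), dtype=np.int32),
         np.zeros((2, 3), dtype=np.int32),
         np.ones((2, 3), dtype=np.float32))
tl.Serial(std, *rev, loss_layer).init(shapes.signature(batch))

trainer = ReversibleSerialTrainer([(std, rev)], loss_layer, optimizers.Adam)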
Example #6
    def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None):
        """Creates a ReversibleSerialTrainer and the needed optimizers.

    This trainer performs updates equivalent to using the default Trainer on::

      tl.Serial(blocks + [loss_layer]).

    It is more memory-efficient though since weights are stored on CPU and only
    sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
    of standard (arbitrary) layers and a list of reversible layers which help
    save memory thanks to being reversible.

    Args:
      blocks: A list of pairs of lists of standard and reversible layers.
      loss_layer: The final layer of the model; it can have trainable weights
        but should end with a loss: it is required to produce a scalar output.
      optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      n_devices: An optional integer, number of accelerator devices to use;
        by default, all available accelerators will be used.
    """
        self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
        self._loss_layer = loss_layer
        self._optimizer_fn = optimizer_fn
        self._n_devices = n_devices or fastmath.device_count()
        self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
        self._n_steps_per_log = 100  # Log layers and stats every 100 steps.

        # Create accelerated versions of layers as pmapped/jitted pure_fn.
        self._accelerated_layer_fns = fastmath.nested_map(
            lambda layer: self._pjit(layer.pure_fn), self._blocks)

        # Create per-layer optimizers and replicate opt_params.
        def _make_optimizer(layer):
            opt = optimizer_fn()
            opt.tree_init(layer.weights)
            return opt

        self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
        self._replicated_opt_params = fastmath.nested_map(
            lambda opt: self._replicate(opt.opt_params), self._optimizers)

        self._loss_opt = _make_optimizer(loss_layer)
        self._replicated_loss_opt_params = self._replicate(
            self._loss_opt.opt_params)

        # Forward + backward + optimizer-update functions for all layers.
        # We call them in short FBO for "Forward + Backward + Optimizer update".
        # Reversible layers define a reverse_and_fbo function that also reverses.

        self._fbos = []
        for i, (std_layer, rev_layers) in enumerate(self._blocks):
            (std_opt, rev_opts) = self._optimizers[i]
            std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt,
                                              self._n_devices)
            rev_and_fbos = []
            for layer, opt in zip(rev_layers, rev_opts):
                rev_and_fbos.append(
                    self._pjit(_reverse_and_fbo_with_layer_and_opt(
                        layer, opt, self._n_devices),
                               donate_argnums=(1, )))
            self._fbos.append((self._pjit(std_fbo,
                                          donate_argnums=(1, )), rev_and_fbos))

        loss_fbo = _fbo_with_layer_and_opt(self._loss_layer, self._loss_opt,
                                           self._n_devices, 'loss')
        self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, ))
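
The `donate_argnums` passed to `_pjit` above uses JAX buffer donation: arguments the caller will not reuse may have their device memory recycled for the outputs, which is how weights and gradients are handled here. A small standalone sketch of the mechanism (the update function is made up for illustration):

import jax
import jax.numpy as jnp

def sgd_update(weights, grads, lr=0.01):
    return weights - lr * grads

# Donating argument 0 tells XLA it may reuse the `weights` buffer for the
# result instead of allocating a fresh one.
update = jax.jit(sgd_update, donate_argnums=(0,))

w = jnp.ones((1024, 1024))
g = jnp.full((1024, 1024), 0.1)
w = update(w, g)  # the old `w` buffer must not be touched after this call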
Example #7
 def __init__(self, layer, n_devices=None):
   super().__init__(n_in=layer.n_in, n_out=layer.n_out)
   self._sublayers = [layer]
   self._n_devices = n_devices or fastmath.device_count()
   self._jit_pure_fn = jit_forward(
       layer.pure_fn, self._n_devices, do_mean=False)
Example #8
    def forward(self, xs):
        self._validate_forward_inputs(xs)
        (step, layers_state) = self.state
        # Get N+1 rngs, N for running layers and one extra.
        rngs = _split_rngs(self.rng, self._n_layers + 1)
        rng0, rngs = rngs[0], rngs[1:]
        if not self.sublayers:  # No-op: leave args unchanged.
            self.state = (step + 1, layers_state)
            return xs

        # Prepare the stack and do some safety checks as in the parent class.
        stack = xs
        new_state = []
        n_layers = self._n_layers
        weights = self.weights
        if n_layers != 1 and len(weights) != n_layers:
            raise ValueError(
                'number of weights ({}) not equal to number of layers '
                '({})'.format(len(weights), n_layers))
        if n_layers != 1 and len(layers_state) != n_layers:
            raise ValueError(
                'length of state ({}) not equal to number of layers '
                '({})'.format(len(layers_state), n_layers))

        # TODO(chowdhery): try different strategies, also try running not all
        # layers backwards by using fastmath.stop_gradient where needed.

        # Calculate how many layers to run forward.
        if self._mode == 'train':
            # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after
            w_steps = float(self._skipping_warmup_steps)
            f_step = jnp.array(step, dtype=jnp.float32)
            warmup = jnp.maximum(0.0, (w_steps - f_step) / w_steps)
            # low is the minimum number of layers to *not* skip, from n_layers to 0
            low = warmup * float(n_layers)
            # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
            # because (high - n_layers) / high is the probability we're not skipping
            # (after warmup); so high - n_layers = high - high * skip_fraction
            high = float(n_layers) / self._skip_fraction
            # We want the same rng0 on all cores.
            if fastmath.device_count() > 1:
                rng0 = fastmath.psum(rng0, 'batch')
            n_forward_layers = random.uniform(rng0, (), jnp.float32, low, high)
        else:
            n_forward_layers = float(n_layers)
        # Run layers skipping after a certain number.
        cur_layer_idx = 0.0
        for layer, p, s, rng in zip(self.sublayers, weights, layers_state,
                                    rngs):
            inputs = _inputs_from_stack(layer, stack)

            def CondF(t):
                o, s = layer.pure_fn(t[0], t[1], t[2], t[3])  # pylint: disable=cell-var-from-loop
                return o, t[1], s, t[3]

            outputs, _, s, _ = fastmath.cond(
                fastmath.lt(cur_layer_idx, n_forward_layers), CondF,
                lambda x: x, (inputs, p, s, rng))
            stack = _outputs_onto_stack(layer, outputs, stack)
            new_state.append(s)
            cur_layer_idx += 1.0
        self.state = (step + 1, new_state)
        return stack
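
A quick numeric check of the bounds computed in the 'train' branch above (plain Python, with made-up values), showing why drawing `n_forward_layers` uniformly from `[low, high)` skips at least one layer with probability `skip_fraction` once warmup has finished:

n_layers, skip_fraction = 12, 0.25
warmup = 0.0                      # after skipping_warmup_steps
low = warmup * n_layers           # 0.0
high = n_layers / skip_fraction   # 48.0
# P(all layers run) = P(n_forward_layers >= n_layers) = (high - n_layers) / high
p_skip_something = 1.0 - (high - n_layers) / high
print(p_skip_something)           # 0.25 == skip_fraction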