Example #1
def _save_replicated(opt_state, step, history, n_devices, output_dir, keep):
  """Save state given a possibly replicated opt_state."""
  if n_devices > 1:
    # Drop the leading device axis by averaging the (identical) replicas.
    unreplicate = lambda x: x.mean(0)
    opt_state = layers.nested_map(opt_state, unreplicate)
  # Save unconditionally, not only in the multi-device case.
  save_state(State(params=opt_state, step=step, history=history),
             output_dir, keep=keep)
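
For context, a minimal plain-NumPy sketch (hypothetical shapes) of what unreplication does here: replicated parameters carry a leading device axis, and averaging over axis 0 (or taking replica 0, as later revisions do) removes it.

import numpy as np

n_devices = 4
# A replicated parameter: identical copies stacked along a leading device axis.
replicated = np.broadcast_to(np.arange(6.0).reshape(2, 3), (n_devices, 2, 3))

unreplicate = lambda x: x.mean(0)  # or x[0]; the replicas are identical
single = unreplicate(replicated)
assert single.shape == (2, 3)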
Example #2
    def reset(self, output_dir):
        """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
    """
        self._output_dir = output_dir
        gfile.makedirs(output_dir)
        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(output_dir, "train"))
        self._eval_sw = jaxboard.SummaryWriter(
            os.path.join(output_dir, "eval"))

        # Reset the train and eval streams.
        self._train_stream = self._inputs.train_stream()
        # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
        #   set by adding a padding and stopping the stream when too large.
        self._eval_stream = _repeat_stream(self._inputs.eval_stream)
        self._train_eval_stream = _repeat_stream(
            self._inputs.train_eval_stream)

        # Restore the training state.
        state = restore_state(output_dir)
        self._step = state.step or 0
        history = state.history
        self._lr_fn = self._lr_schedule(history)
        self._history = history
        if state.opt_state:
            opt_state = state.opt_state
            model_state = state.model_state
        else:
            opt_state, model_state = self._initialize()
            model_state = layers.nested_map(model_state, self._maybe_replicate)
        self._opt_state = OptState(
            *layers.nested_map(opt_state, self._maybe_replicate))
        self._model_state = model_state
        if not state.opt_state:
            self._maybe_save_state(keep=False)

        self.update_optimizer_params()
Example #3
def _print_n_params(opt_state, n_devices, step):
    """Print out the number of parameters."""
    sizes = layers.sizes(opt_state.params)
    if n_devices > 1:
        unreplicate = lambda x: x[0]
        single_params = layers.nested_map(opt_state.params, unreplicate)
        sizes = layers.sizes(single_params)
    total_size = layers.nested_reduce(sizes, sum)
    step_log(step, "Total trainable parameters size: %d" % total_size)
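
To illustrate the size accounting, a short sketch over a hypothetical nested parameter structure, using plain NumPy and a recursive sum in place of the layers.sizes / layers.nested_reduce helpers.

import numpy as np

params = ((np.zeros((512, 1024)), np.zeros((1024,))),  # e.g. one dense layer
          (np.zeros((1024, 256)), np.zeros((256,))))   # e.g. another layer

def total_size(tree):
    """Sum array sizes over an arbitrarily nested tuple/list of arrays."""
    if isinstance(tree, (list, tuple)):
        return sum(total_size(t) for t in tree)
    return tree.size

print(total_size(params))  # 512*1024 + 1024 + 1024*256 + 256 = 787712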
Example #4
def _save_replicated(opt_state, step, history, n_devices, output_dir, keep):
  """Save state but given a possibly replicated opt_state."""
  if n_devices > 1:
    first_replica = lambda x: x[0]
    opt_state = layers.nested_map(opt_state, first_replica)
  # This line, while optional, allows JAX to transfer arrays from the device to
  # the host in parallel, which is particularly important for cloud TPU.
  if backend.get_name() == "jax":
    opt_state = jax.device_get(opt_state)
  save_state(State(params=opt_state, step=step, history=history),
             output_dir, keep=keep)
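
The jax.device_get call above fetches the whole (possibly nested) structure back to host memory as NumPy arrays; a minimal sketch of that behaviour, assuming a working JAX installation.

import jax
import jax.numpy as jnp
import numpy as np

on_device = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
on_host = jax.device_get(on_device)  # every leaf is transferred to the host

assert isinstance(on_host["w"], np.ndarray)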
Example #5
    def reset(self, output_dir):
        """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
    """
        self._output_dir = output_dir
        gfile.makedirs(output_dir)
        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(output_dir, "train"))
        self._eval_sw = jaxboard.SummaryWriter(
            os.path.join(output_dir, "eval"))

        # Reset the training stream.
        self._train_stream = self._inputs.train_stream()

        # Restore the training state.
        state = restore_state(output_dir)
        self._step = state.step or 0
        history = state.history
        self._lr_fn = self._lr_schedule(history)
        self._history = history
        if state.opt_state:
            opt_state = state.opt_state
            model_state = state.model_state
        else:
            opt_state, model_state = self._initialize()
            model_state = layers.nested_map(model_state, self._maybe_replicate)
        self._opt_state = OptState(
            *layers.nested_map(opt_state, self._maybe_replicate))
        self._model_state = model_state
        if not state.opt_state:
            self._maybe_save_state(keep=False)

        self.update_learning_rate()
Example #6
    def predict(x, params=(), state=(), rng=None):
        """Predict function jited and parallelized as requested."""
        pred, state = mapped_predict(reshape_by_device(x, n_devices), params,
                                     state, jax_random.split(rng, n_devices))

        # Need to reduce the [device, per-device-batch, ...] tensors back to
        # a [batch, ...] tensor. The tensors may be nested.
        def combine(x):
            batch_size = x.shape[0] * x.shape[1]
            return np.reshape(x, [batch_size] + list(x.shape[2:]))

        return layers.nested_map(pred, combine), state
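
A small plain-NumPy sketch (hypothetical sizes) of what combine does to a single sharded output: the device axis and the per-device batch axis are merged back into one batch axis.

import numpy as np

n_devices, per_device_batch = 2, 4
pred = np.zeros((n_devices, per_device_batch, 10))  # [device, batch/device, ...]

combined = np.reshape(pred, [n_devices * per_device_batch] + list(pred.shape[2:]))
assert combined.shape == (8, 10)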
Example #7
    def _train_step(self, next_train_batch):
        """Run one training step and update self._opt_state."""
        # Calculate the current optimizer parameters (e.g. the learning rate).
        opt_param_updates = layers.nested_map(
            self.optimizer_params,
            lambda x: self._maybe_replicate(np.array(x)))
        opt_state = self._opt_state
        opt_state.opt_params.update(opt_param_updates)

        # Run the update.
        (params, slots), self._model_state, self._rngs = self._jit_update_fn(
            self._step, opt_state, next_train_batch, self._model_state,
            self._rngs)
        self._opt_state = opt_state._replace(params=params, slots=slots)
        self._step += 1
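
For a scalar hyperparameter such as the learning rate, _maybe_replicate(np.array(x)) amounts to broadcasting over a leading device axis in the multi-device case; a sketch of that assumption, mirroring the replicate lambdas used elsewhere in these examples.

import numpy as np

n_devices = 8
learning_rate = np.array(1e-3)  # scalar hyperparameter
replicated_lr = np.broadcast_to(learning_rate, (n_devices,) + learning_rate.shape)
assert replicated_lr.shape == (n_devices,)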
Example #8
    def predict(x, params=(), state=(), rng=None):
        """Predict function jited and parallelized as requested."""
        pred = mapped_predict(reshape_by_device(x, n_devices), params, state,
                              jax_random.split(rng, n_devices))

        # Need to reduce the [device, per-device-batch, ...] tensors back to
        # a [batch, ...] tensor. The tensors may be nested.
        def combine(x):
            if len(x.shape) > 1:
                batch_size = x.shape[0] * x.shape[1]
                return np.reshape(x, [batch_size] + list(x.shape[2:]))
            # TODO(lukaszkaiser): is returning averages for scalars the right choice?
            # If it is only scalar, return the average.
            return np.mean(x, axis=0)

        return layers.nested_map(pred, combine)
Example #9
  def _train_step(self, next_train_batch):
    """Run one training step and update self._opt_state."""
    # Calculate the current optimizer parameters.
    # TODO(pkozakowski): Optimizer parameters get polluted with model state,
    # which doesn't break anything but is weird. Filter it out.
    opt_param_updates = layers.nested_map(
        self.nontrainable_params, lambda x: self._maybe_replicate(np.array(x))
    )
    opt_state = self._opt_state
    opt_state.opt_params.update(opt_param_updates)

    # Run the update.
    (params, slots), self._model_state, self._rngs = self._jit_update_fn(
        self._step, opt_state, next_train_batch, self._model_state, self._rngs)
    self._model_state = self._map_to_state_dicts(self._state_dicts_update)
    self._opt_state = opt_state._replace(params=params, slots=slots)
    self._step += 1
Example #10
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 save_steps=None,
                 should_save=True,
                 has_weights=False):
        if save_steps is None:
            save_steps = []
        self._save_steps = save_steps
        self._should_save = should_save
        self._has_weights = has_weights
        loss_fn = functools.partial(loss_fn, has_weights=self._has_weights)
        device_count = jax.lib.xla_bridge.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count:
            raise ValueError(
                "Jax cannot work yet with n_devices != all devices: "
                "%d != %d" % (n_devices, device_count))
        self._n_devices = n_devices
        rng = get_random_number_generator_and_set_seed(random_seed)
        inputs = inputs(n_devices)
        self._inputs = inputs

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode="train")
        model_predict_eval = model(mode="eval")

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = jax_random.split(rng, n_devices)
        first_shape = inputs.input_shape[0]
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if isinstance(first_shape, (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.input_shape)
            model_target_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.target_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(inputs.input_shape))
            model_target_shape = tuple([None] + list(inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = layers.nested_map(
            model_input_shape, lambda x: x if x else 1)
        model_target_shape = layers.nested_map(
            model_target_shape, lambda x: x if x else 1)

        def initialize(input_shape, input_dtype, target_shape, target_dtype,
                       rng):
            """Helper to initialize the model."""
            # Combine inputs and targets on the stack.
            if not isinstance(input_dtype, (list, tuple)):
                input_dtype = [input_dtype]
                input_shape = [input_shape]
            if not isinstance(target_dtype, (list, tuple)):
                target_dtype = [target_dtype]
                target_shape = [target_shape]
            full_type = list(input_dtype) + list(target_dtype)
            full_shape = list(input_shape) + list(target_shape)
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.initialize` puts cached parameter values in `m` and hence the
            # next call of `m.initialize` will give wrong results.
            params, state = model(mode="train").initialize(
                full_shape, full_type, rng)
            (slots, opt_params) = opt.tree_init(params)
            return (OptState(params, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            initialize = backend.jit(initialize, static_argnums=(0, 1, 2, 3))
        self._initialize = lambda: initialize(  # pylint: disable=g-long-lambda
            model_input_shape, self._inputs.input_dtype, model_target_shape,
            self._inputs.target_dtype, init_rng)

        # jit model_predict and update so they're fast
        self._jit_model_predict_eval = _jit_predict_fn(model_predict_eval,
                                                       n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None

        if output_dir is not None:
            self.reset(output_dir)
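
The None-to-1 substitution above is a per-element map over the (possibly nested) shape tuples; a sketch with a hypothetical shape, using a small recursive map in place of layers.nested_map.

model_input_shape = ((None, 32, 32, 3), (None, 10))  # hypothetical nested shape

def map_nested(shape, f):
    """Apply f to every leaf of a nested tuple/list structure."""
    if isinstance(shape, (list, tuple)):
        return tuple(map_nested(s, f) for s in shape)
    return f(shape)

print(map_nested(model_input_shape, lambda x: x if x else 1))
# ((1, 32, 32, 3), (1, 10))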
Example #11
def reshape_by_device(x, n_devices):
    """Reshape possibly nested x into a shape [n_devices, ...]."""
    return layers.nested_map(x,
                             lambda x: _reshape_by_device_single(x, n_devices))
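
A minimal sketch of the per-leaf reshape this wraps (hypothetical batch size, plain NumPy): a [batch, ...] array becomes [n_devices, batch // n_devices, ...] so that pmap can map over the leading device axis.

import numpy as np

n_devices = 2
x = np.arange(8 * 3).reshape(8, 3)  # a [batch, feature] array

per_device = np.reshape(x, [n_devices, 8 // n_devices] + list(x.shape[1:]))
assert per_device.shape == (2, 4, 3)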
Example #12
  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs, output_dir,
               random_seed=None, n_devices=None, save_steps=None):
    if save_steps is None:
      save_steps = []
    self._save_steps = save_steps
    device_count = jax.lib.xla_bridge.device_count()
    n_devices = n_devices or device_count
    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != device_count:
      raise ValueError("Jax cannot work yet with n_devices != all devices: "
                       "%d != %d" % (n_devices, device_count))
    self._n_devices = n_devices
    rng = get_random_number_generator_and_set_seed(random_seed)
    self._output_dir = output_dir
    gfile.makedirs(output_dir)
    # Create summary writers and history.
    self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    # Create input streams.
    inputs = inputs(n_devices)
    self._inputs = inputs
    self._train_stream = inputs.train_stream()

    # Setup optimizer and model.
    state = restore_state(output_dir)
    history = state.history
    self._lr_fn = lr_schedule(history)
    opt = optimizer(self._lr_fn)

    model_train = model(mode="train")
    model_predict_eval = model(mode="eval")

    # Setup state.
    step = state.step or 0
    rng, init_rng = jax_random.split(rng)
    self._rngs = jax_random.split(rng, n_devices)
    first_shape = inputs.input_shape[0]
    # If the inputs are a tuple/list, add [None] (batch) to each element.
    if isinstance(first_shape, (list, tuple)):
      model_input_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.input_shape)
    else:  # Otherwise just add [None] to the input shape.
      model_input_shape = tuple([None] + list(inputs.input_shape))
    # Change all None to 1 in input shape.
    model_input_shape = layers.nested_map(
        model_input_shape, lambda x: x if x else 1)
    if state.params:
      params = state.params[0]
      opt_state = state.params
    else:
      params = model_train.initialize(
          model_input_shape, inputs.input_dtype, init_rng)
      opt_state = (params, opt.tree_init(params))
    if n_devices > 1:
      replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
      opt_state = layers.nested_map(opt_state, replicate)

    # jit model_predict and update so they're fast
    self._jit_model_predict_eval = _jit_predict_fn(
        model_predict_eval, n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    self._step = step
    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    self._optimizer = optimizer
    self._opt_state = opt_state
    self._history = history
    self._lr_schedule = lr_schedule
Example #13
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          save_steps=None,
          eval_steps=10,
          eval_frequency=100,
          n_devices=None,
          random_seed=None,
          run_debug_step=False,
          save_graphs=True,
          save_backward_graph=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    save_steps: list of integers. Keep a model file at each of the supplied save
      steps.
    eval_steps: int, number of steps per evaluation. If None or 0, eval is
      disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval is disabled.
    n_devices: int, how many devices to use (if None, the default, use all
      available devices).
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit for
      one step.
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.
  Returns:
    trax.State
  """
  if save_steps is None:
    save_steps = []
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs(n_devices)

  # Setup optimizer and model
  state = restore_state(output_dir)
  history = state.history
  lr_fn = lr_schedule(history)
  opt = optimizer(lr_fn)

  model_train = layers.Serial(model(mode="train"))
  model_predict_eval = layers.Serial(model(mode="eval"))

  # Setup state
  step = state.step or 0
  rng, init_rng = jax_random.split(rng)
  rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [-1] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        [tuple([-1] + list(shape)) for shape in inputs.input_shape])
  else:  # Otherwise just add [-1] to the input shape.
    model_input_shape = tuple([-1] + list(inputs.input_shape))
  if state.params:
    params = state.params[0]
    opt_state = state.params
  else:
    params = model_train.initialize(model_input_shape, init_rng)
    opt_state = (params, opt.tree_init(params))
  if n_devices > 1:
    replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
    opt_state = layers.nested_map(opt_state, replicate)

  # jit model_predict and update so they're fast
  jit_model_predict_eval = _jit_predict_fn(model_predict_eval, n_devices)
  jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % n_devices)

  # A non-compiled debug step makes it easier to find problems in models.
  if run_debug_step:
    debug_loss = loss_fn(params, next(train_stream), model_train, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(train_stream)
      if n_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device(next_train_batch, n_devices)
      opt_state, rngs = jit_update_fn(step, opt_state, next_train_batch, rngs)
      step += 1

      if step in save_steps:
        _save_replicated(opt_state, step, history, n_devices, output_dir, True)

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate",
                        lr_fn(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Print number of parameters
    if step == 1:
      sizes = layers.sizes(opt_state[0])
      if n_devices > 1:
        unreplicate = lambda x: x.mean(0)
        single_params = layers.nested_map(opt_state[0], unreplicate)
        sizes = layers.sizes(single_params)
      total_size = layers.nested_reduce(sizes, sum)
      step_log(step, "Total trainable parameters size: %d" % total_size)

    # Evaluate in parallel
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fn=functools.partial(jit_model_predict_eval,
                                     params=opt_state[0]),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save computation graph (single-device only for now).
    if save_graphs and step == 1 and n_devices == 1:
      params = opt_state[0]
      # Dump computation graphs to files.
      forward_computation = jax.xla_computation(model_predict_eval)(
          next_train_batch[0], params=params, rng=rng)
      with gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f:
        f.write(forward_computation.GetHloText())
      with gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f:
        f.write(forward_computation.GetHloDotGraph())
      backward_computation = jax.xla_computation(jit_update_fn)(
          step, opt_state, next_train_batch, rngs)
      with gfile.GFile(os.path.join(output_dir, "backward.txt"), "w") as f:
        f.write(backward_computation.GetHloText())
      if save_backward_graph:  # Backward graphs can be large so we guard it.
        with gfile.GFile(os.path.join(output_dir, "backward.dot"), "w") as f:
          f.write(backward_computation.GetHloDotGraph())

    # Save state
    _save_replicated(opt_state, step, history, n_devices, output_dir, False)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fn = lr_fn
    lr_fn = lr_schedule(history)
    if lr_fn != old_lr_fn:  # For performance, only jit if there is a change.
      opt = optimizer(lr_fn)
      jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    # Flush summary writers
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=opt_state, step=step, history=history)
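
The epoch_steps iterator above controls how many train steps run between evaluations; a short sketch of the schedule it produces for a hypothetical eval_frequency.

import itertools

eval_frequency = 100
epoch_steps = itertools.chain([1,                    # first epoch: 1 step
                               eval_frequency - 1],
                              itertools.repeat(eval_frequency))
print(list(itertools.islice(epoch_steps, 5)))  # [1, 99, 100, 100, 100]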
Example #14
  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
               output_dir=None, random_seed=None, n_devices=None,
               save_steps=None, should_save=True):
    if save_steps is None:
      save_steps = []
    self._save_steps = save_steps
    self._should_save = should_save
    device_count = jax.lib.xla_bridge.device_count()
    n_devices = n_devices or device_count
    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != device_count:
      raise ValueError("Jax cannot work yet with n_devices != all devices: "
                       "%d != %d" % (n_devices, device_count))
    self._n_devices = n_devices
    rng = get_random_number_generator_and_set_seed(random_seed)
    inputs = inputs(n_devices)
    self._inputs = inputs

    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode="train")
    model_predict_eval = model(mode="eval")

    # Setup state.
    rng, init_rng = jax_random.split(rng)
    self._rngs = jax_random.split(rng, n_devices)
    first_shape = inputs.input_shape[0]
    # If the inputs are a tuple/list, add [None] (batch) to each element.
    if isinstance(first_shape, (list, tuple)):
      model_input_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.input_shape)
    else:  # Otherwise just add [None] to the input shape.
      model_input_shape = tuple([None] + list(inputs.input_shape))
    # Change all None to 1 in input shape.
    model_input_shape = layers.nested_map(
        model_input_shape, lambda x: x if x else 1)
    def initialize(input_shape, input_dtype, init_rng):
      params = model_train.initialize(input_shape, input_dtype, init_rng)
      (slots, opt_params) = opt.tree_init(params)
      return OptState(params, slots, opt_params)
    if _is_jit_init():
      # JIT parameter initialization to avoid memory fragmentation
      initialize = backend.jit(initialize, static_argnums=(0, 1))
    self._initialize = lambda: initialize(  # pylint: disable=g-long-lambda
        model_input_shape, self._inputs.input_dtype, init_rng)

    # jit model_predict and update so they're fast
    self._jit_model_predict_eval = _jit_predict_fn(
        model_predict_eval, n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    self._lr_schedule = lr_schedule

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None

    if output_dir is not None:
      self.reset(output_dir)
Example #15
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 save_steps=None,
                 should_save=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 mask_id=None):
        if save_steps is None:
            save_steps = []
        self._save_steps = save_steps
        self._should_save = should_save
        self._has_weights = has_weights
        self._mask_id = mask_id
        loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
        device_count = jax.lib.xla_bridge.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count:
            raise ValueError(
                "Jax cannot work yet with n_devices != all devices: "
                "%d != %d" % (n_devices, device_count))
        self._n_devices = n_devices
        rng = get_random_number_generator_and_set_seed(random_seed)
        inputs = inputs(n_devices)
        self._inputs = inputs

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode="train")
        model_predict_eval = model(mode="eval")

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = jax_random.split(rng, n_devices)
        first_shape = inputs.input_shape[0]
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if isinstance(first_shape, (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.input_shape)
            model_target_shape = tuple(
                tuple([None] + list(shape)) for shape in inputs.target_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(inputs.input_shape))
            model_target_shape = tuple([None] + list(inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = layers.nested_map(
            model_input_shape, lambda x: x if x else 1)
        model_target_shape = layers.nested_map(
            model_target_shape, lambda x: x if x else 1)

        def new_opt_state_and_model_state(input_shape, input_dtype,
                                          target_shape, target_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            if not isinstance(input_dtype, (list, tuple)):
                input_dtype = [input_dtype]
                input_shape = [input_shape]
            if not isinstance(target_dtype, (list, tuple)):
                target_dtype = [target_dtype]
                target_shape = [target_shape]
            full_type = list(input_dtype) + list(target_dtype)
            full_shape = list(input_shape) + list(target_shape)
            if self._has_weights:
                full_shape += list(target_shape)
                full_type += [np.float32 for _ in target_dtype]
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.initialize` puts cached parameter values in `m` and hence the
            # next call of `m.initialize` will give wrong results.
            m = layers.Serial([model(mode="train"), loss_fn])
            params, state = m.initialize_once(full_shape, full_type, rng)
            (slots, opt_params) = opt.tree_init(params)
            return (OptState(params, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = backend.jit(
                new_opt_state_and_model_state, static_argnums=(0, 1, 2, 3))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                model_input_shape, self._inputs.input_dtype,
                model_target_shape, self._inputs.target_dtype, init_rng))

        # jit model_predict and update so they're fast
        # TODO(lukaszkaiser): the code below creates a layer computing
        # multiple metrics from a single model output; re-factor for clarity.
        dup_layer = layers.Dup3() if self._has_weights else layers.Dup2()

        def lower(layer):
            """Apply layer below the current inputs, targets, and possibly weights."""
            if self._has_weights:
                # Apply layer below inputs, targets, and loss weights.
                return layers.Parallel([], [], [], layer)
            else:
                # Apply layer below inputs and targets.
                return layers.Parallel([], [], layer)

        metrics_layer = []
        self._metrics = list(sorted(_METRICS.keys()))
        for i, m in enumerate(reversed(self._metrics)):
            metric = _METRICS[m](has_weights=self._has_weights,
                                 mask_id=self._mask_id)
            if i != len(self._metrics) - 1:
                metrics_layer.append(dup_layer)
                metrics_layer.append(lower(metric))
            else:
                metrics_layer.append(metric)
        # TODO(lukaszkaiser): clean this up once layer API stabilizes.
        # For now, we need to initialize metric layers somehow, so here we go.
        # We assume that they do not have any parameters, so this is a dummy.
        dummy_shape = (((1, 2), (1,), (1,)) if self._has_weights
                       else ((1, 2), (1,)))
        dummy_type = [np.float32] * (3 if self._has_weights else 2)
        metrics_layer = layers.Serial(metrics_layer)
        metrics_params, metrics_state = metrics_layer.initialize_once(
            dummy_shape, tuple(dummy_type), init_rng)
        self._metrics_params = layers.nested_map(metrics_params,
                                                 self._maybe_replicate)
        self._metrics_state = layers.nested_map(metrics_state,
                                                self._maybe_replicate)
        self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_layer,
                                         n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to control
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None

        if output_dir is not None:
            self.reset(output_dir)
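
Conceptually, the Dup/Parallel stack above fans one (prediction, target[, weights]) group out to every metric; a plain-Python sketch of that idea with hypothetical metric functions, without the trax combinators.

import numpy as np

def accuracy(pred, target):
    return np.mean(np.argmax(pred, axis=-1) == target)

def neg_log_perplexity(pred, target):
    log_probs = pred - np.log(np.sum(np.exp(pred), axis=-1, keepdims=True))
    return np.mean(np.take_along_axis(log_probs, target[..., None], axis=-1))

metrics = {"accuracy": accuracy, "neg_log_perplexity": neg_log_perplexity}
pred = np.random.randn(4, 10)                 # one shared model output
target = np.random.randint(0, 10, size=(4,))
results = {name: fn(pred, target) for name, fn in metrics.items()}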