  def save_computation_graphs(self, save_backward_graph):
    """Dump computation graphs to files."""
   if self._n_devices != 1:
     return  # TODO(lukaszkaiser): make this work with more devices.
   next_train_batch = next(self._train_stream)
   output_dir = self._output_dir
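    # Note: with the early return above, this branch is currently unreachable;
    # it is kept in place for the multi-device TODO.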
   if self._n_devices > 1:
     next_train_batch = backend.reshape_by_device(
         next_train_batch, self._n_devices)
   weights = self._opt_state[0][0]
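    # jax.xla_computation traces the function into an XLA HLO computation
    # without running it; we dump the result as HLO text and as a dot graph.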
   forward_computation = jax.xla_computation(self._model_predict_eval)(
       next_train_batch, weights=weights, state=self._model_state[0],
       rng=self._rngs[0])
   with gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f:
     f.write(forward_computation.GetHloText())
   with gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f:
     f.write(forward_computation.GetHloDotGraph())
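    # The update function spans the full training step (it takes the optimizer
    # state), so its graph is typically much larger than the forward one.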
   backward_computation = jax.xla_computation(self._jit_update_fn)(
       self._step, self._opt_state, next_train_batch, self._model_state,
       self._rngs)
   with gfile.GFile(os.path.join(output_dir, 'backward.txt'), 'w') as f:
     f.write(backward_computation.GetHloText())
   if save_backward_graph:  # Backward graphs can be large so we guard it.
     with gfile.GFile(os.path.join(output_dir, 'backward.dot'), 'w') as f:
       f.write(backward_computation.GetHloDotGraph())
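  # Hypothetical usage sketch (not from the original code), assuming a trainer
  # instance `trainer` running on a single device:
  #   trainer.save_computation_graphs(save_backward_graph=False)
  # writes forward.txt, forward.dot and backward.txt under the output dir.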
  def train_epoch(self, epoch_steps, eval_steps):
    """Runs the trainer for `epoch_steps` steps."""
    print()  # Add visual separator in logs for start of training epoch.
    start_time = time.time()

    for _ in range(epoch_steps):
      next_train_batch = next(self._train_stream)
      if self._n_devices > 1:  # TODO(lukaszkaiser): use everywhere if possible.
        next_train_batch = backend.reshape_by_device(
            next_train_batch, self._n_devices)
      self._train_step(next_train_batch)

      # Occasionally save state, and occasionally log nontrainable parameters
      # (e.g., learning rate, dropout).
      if self._step in self._save_steps and self.is_chief:
        self._maybe_save_state(keep=True)
      if (self._step == 1 or self._step % 10 == 0) and self._train_sw:
        for (name, value) in self.nontrainable_params.items():
          self._train_sw.scalar('training/{}'.format(name), value)

    # At end of epoch, do bookkeeping, run evals, and save state.
    epoch_time = time.time() - start_time
    step_log(self._step, 'Ran %d train steps in %0.2f secs' %
             (epoch_steps, epoch_time))
    if epoch_steps > 1 and self._train_sw:
      self._train_sw.scalar('training/steps per second',
                            epoch_steps / epoch_time, step=self._step)
    self.evaluate(eval_steps)
    if self.is_chief:
      self._maybe_save_state(keep=False)
    if self._train_sw:
      self._train_sw.flush()
      self._eval_sw.flush()
  def predict(x, weights, state, rng):
    """Predict function, jitted and parallelized as requested."""
   res, state = backend.combine_devices(model_predict(
       backend.reshape_by_device(x, n_devices),
       weights,
       state,
       np.stack(jax_random.split(rng, n_devices))))
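    # Average each output array over its leading axis before returning.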
   return layers.nested_map(lambda y: np.mean(y, axis=0), res), state
 def compute_loss(opt_state, batch, state, rng):
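    # Shard the batch across devices and delegate to the device-mapped loss.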
   return mapped_compute_loss(
       opt_state, backend.reshape_by_device(batch, n_devices), state, rng)