def save_computation_graphs(self, save_backward_graph):
  """Dump computation graphs to files."""
  if self._n_devices != 1:
    return  # TODO(lukaszkaiser): make this work with more devices.
  next_train_batch = next(self._train_stream)
  output_dir = self._output_dir
  if self._n_devices > 1:
    next_train_batch = backend.reshape_by_device(
        next_train_batch, self._n_devices)
  weights = self._opt_state[0][0]
  forward_computation = jax.xla_computation(self._model_predict_eval)(
      next_train_batch, weights=weights, state=self._model_state[0],
      rng=self._rngs[0])
  with gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f:
    f.write(forward_computation.GetHloText())
  with gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f:
    f.write(forward_computation.GetHloDotGraph())
  backward_computation = jax.xla_computation(self._jit_update_fn)(
      self._step, self._opt_state, next_train_batch, self._model_state,
      self._rngs)
  with gfile.GFile(os.path.join(output_dir, 'backward.txt'), 'w') as f:
    f.write(backward_computation.GetHloText())
  if save_backward_graph:  # Backward graphs can be large so we guard it.
    with gfile.GFile(os.path.join(output_dir, 'backward.dot'), 'w') as f:
      f.write(backward_computation.GetHloDotGraph())
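# A minimal, self-contained sketch of the tracing mechanism used above:
# jax.xla_computation traces a function on example arguments and returns an
# XLA computation whose HLO can be written out. The toy loss below is a
# hypothetical stand-in for the model. Note that newer JAX versions expose
# as_hlo_text()/as_hlo_dot_graph() where older ones (as used above) had
# GetHloText()/GetHloDotGraph().
import jax
import jax.numpy as jnp

def _toy_loss(w, x):
  return jnp.sum((x @ w) ** 2)

_w = jnp.ones((4, 2))
_x = jnp.ones((3, 4))
_computation = jax.xla_computation(_toy_loss)(_w, _x)
print(_computation.as_hlo_text())  # The same text that forward.txt receives.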
def train_epoch(self, epoch_steps, eval_steps):
  """Runs the trainer for `epoch_steps` steps."""
  print()  # Add visual separator in logs for start of training epoch.
  start_time = time.time()

  for _ in range(epoch_steps):
    next_train_batch = next(self._train_stream)
    if self._n_devices > 1:  # TODO(lukaszkaiser): use everywhere if possible.
      next_train_batch = backend.reshape_by_device(
          next_train_batch, self._n_devices)
    self._train_step(next_train_batch)

    # Occasionally save state, and occasionally log nontrainable weights
    # (e.g., learning rate, dropout).
    if self._step in self._save_steps and self.is_chief:
      self._maybe_save_state(keep=True)
    if (self._step == 1 or self._step % 10 == 0) and self._train_sw:
      for (name, value) in self.nontrainable_params.items():
        self._train_sw.scalar('training/{}'.format(name), value)

  # At end of epoch, do bookkeeping, run evals, and save state.
  epoch_time = time.time() - start_time
  step_log(self._step, 'Ran %d train steps in %0.2f secs' %
           (epoch_steps, epoch_time))
  if epoch_steps > 1 and self._train_sw:
    self._train_sw.scalar('training/steps per second',
                          epoch_steps / epoch_time, step=self._step)
  self.evaluate(eval_steps)
  if self.is_chief:
    self._maybe_save_state(keep=False)
  if self._train_sw:
    self._train_sw.flush()
    self._eval_sw.flush()
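# A hedged usage sketch of the epoch loop above: an outer driver alternating
# train_epoch calls until a step budget is reached. The `trainer` object and
# its `step` property are assumptions standing in for however this class is
# instantiated; the budget arithmetic is illustrative, not the library's API.
def run_training(trainer, total_steps, epoch_steps=100, eval_steps=10):
  while trainer.step < total_steps:
    # Don't overshoot the budget on the last (possibly shorter) epoch.
    steps_this_epoch = min(epoch_steps, total_steps - trainer.step)
    trainer.train_epoch(steps_this_epoch, eval_steps)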
def predict(x, weights, state, rng):
  """Predict function jitted and parallelized as requested."""
  res, state = backend.combine_devices(model_predict(
      backend.reshape_by_device(x, n_devices),
      weights,
      state,
      np.stack(jax_random.split(rng, n_devices))))
  return layers.nested_map(lambda y: np.mean(y, axis=0), res), state
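# A minimal sketch of the data movement predict() relies on, assuming
# backend.reshape_by_device splits the leading batch axis into
# (n_devices, batch // n_devices, ...) and backend.combine_devices merges a
# device axis back into the batch axis; these standalone versions are
# assumptions written for illustration, not the backend module's code.
import jax.numpy as jnp

def _reshape_by_device_sketch(x, n_devices):
  batch = x.shape[0]
  assert batch % n_devices == 0, 'Batch must divide evenly across devices.'
  return jnp.reshape(x, (n_devices, batch // n_devices) + x.shape[1:])

def _combine_devices_sketch(x):
  # Merge the leading device axis back into the batch axis.
  return jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:])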
def compute_loss(opt_state, batch, state, rng):
  # Shard the batch across devices before calling the mapped loss.
  return mapped_compute_loss(
      opt_state, backend.reshape_by_device(batch, n_devices), state, rng)
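# A minimal sketch of what mapped_compute_loss could be, assuming it is a
# jax.pmap of a per-device loss averaged across devices with jax.lax.pmean.
# The per-device stand-in model and the axis name 'batch' are hypothetical;
# when called, both arguments need a leading device axis, as produced by the
# reshape above.
import jax
import jax.numpy as jnp

def _per_device_loss(weights, batch):
  inputs, targets = batch
  preds = inputs @ weights  # Stand-in for the real model's forward pass.
  loss = jnp.mean((preds - targets) ** 2)
  return jax.lax.pmean(loss, axis_name='batch')  # Average across devices.

_mapped_loss_sketch = jax.pmap(_per_device_loss, axis_name='batch')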