def _print_n_params(opt_state, n_devices, step):
  """Logs the total number of trainable parameters.

  Args:
    opt_state: object exposing a `params` attribute holding the parameter
      tree. When `n_devices > 1` every leaf is expected to carry a leading
      device axis (replicated params).
    n_devices: number of devices the params are replicated over.
    step: training step, forwarded to `step_log`.
  """
  params = opt_state.params
  if n_devices > 1:
    # Count a single replica only; compute sizes once on the unreplicated
    # tree instead of sizing the replicated tree and discarding the result.
    params = layers.nested_map(params, lambda x: x[0])
  sizes = layers.sizes(params)
  total_size = layers.nested_reduce(sizes, sum)
  step_log(step, "Total trainable parameters size: %d" % total_size)
def _train_model_input_shape(input_shape):
  """Prepends a -1 (batch) dimension to an input shape or tuple of shapes.

  Args:
    input_shape: a single shape (sequence of ints), or a sequence of shapes
      when its first element is itself a list/tuple.

  Returns:
    A tuple shape with -1 prepended, or a tuple of such shapes.
  """
  if isinstance(input_shape[0], (list, tuple)):
    return tuple(tuple([-1] + list(shape)) for shape in input_shape)
  return tuple([-1] + list(input_shape))


def _dump_computation_graphs(output_dir, model_predict_eval, jit_update_fn,
                             step, opt_state, batch, rng, rngs,
                             save_backward_graph):
  """Writes HLO text/dot dumps of the forward and update computations.

  Writes forward.txt, forward.dot and backward.txt into `output_dir`, plus
  backward.dot when `save_backward_graph` is true (it can be large).
  """
  params = opt_state[0]
  forward_computation = jax.xla_computation(model_predict_eval)(
      batch[0], params=params, rng=rng)
  with gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f:
    f.write(forward_computation.GetHloText())
  with gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f:
    f.write(forward_computation.GetHloDotGraph())
  backward_computation = jax.xla_computation(jit_update_fn)(
      step, opt_state, batch, rngs)
  with gfile.GFile(os.path.join(output_dir, "backward.txt"), "w") as f:
    f.write(backward_computation.GetHloText())
  if save_backward_graph:  # Backward graphs can be large so we guard it.
    with gfile.GFile(os.path.join(output_dir, "backward.dot"), "w") as f:
      f.write(backward_computation.GetHloDotGraph())


def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          save_steps=None,
          eval_steps=10,
          eval_frequency=100,
          n_devices=None,
          random_seed=None,
          run_debug_step=False,
          save_graphs=True,
          save_backward_graph=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train: a callable taking `mode=` ("train" or "eval")
      and returning a layer (it is wrapped in `layers.Serial`).
    loss_fn: callable with signature: params, batch, model, rng -> loss.
    inputs: callable taking n_devices and returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    save_steps: list of integers. Keep a model file at each of the supplied
      save steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    n_devices: how many devices to use (if None, default, use all available)
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit
      for one step.
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.

  Returns:
    trax.State with the final optimizer state (as `params`), step and history.

  Raises:
    ValueError: if n_devices differs from the number of available devices.
  """
  if save_steps is None:
    save_steps = []
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs(n_devices)

  # Setup optimizer and model
  state = restore_state(output_dir)
  history = state.history
  lr_fn = lr_schedule(history)
  opt = optimizer(lr_fn)

  model_train = layers.Serial(model(mode="train"))
  model_predict_eval = layers.Serial(model(mode="eval"))

  # Setup state
  step = state.step or 0
  rng, init_rng = jax_random.split(rng)
  rngs = jax_random.split(rng, n_devices)
  model_input_shape = _train_model_input_shape(inputs.input_shape)
  if state.params:
    params = state.params[0]
    opt_state = state.params
  else:
    params = model_train.initialize(model_input_shape, init_rng)
    opt_state = (params, opt.tree_init(params))
  if n_devices > 1:
    def _replicate(x):
      return numpy.broadcast_to(x, (n_devices,) + x.shape)
    opt_state = layers.nested_map(opt_state, _replicate)

  # jit model_predict and update so they're fast
  jit_model_predict_eval = _jit_predict_fn(model_predict_eval, n_devices)
  jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  train_stream = inputs.train_stream()
  # Only training if eval is disabled (eval_frequency or eval_steps falsy).
  epoch_steps_schedule = [train_steps]
  # eval_steps may be None (documented as "eval disabled"); guard it before
  # the numeric comparison, which would otherwise raise TypeError.
  if eval_frequency and eval_steps and eval_steps > 0:
    epoch_steps_schedule = itertools.chain(
        [1,  # first epoch only 1 step
         eval_frequency - 1],
        itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % n_devices)

  # Non-compiled debug step helps find problems in models easier.
  if run_debug_step:
    debug_loss = loss_fn(params, next(train_stream), model_train, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps_schedule):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(train_stream)
      if n_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device(next_train_batch, n_devices)
      opt_state, rngs = jit_update_fn(step, opt_state, next_train_batch, rngs)
      step += 1

      if step in save_steps:
        _save_replicated(opt_state, step, history, n_devices, output_dir, True)

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate", lr_fn(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step,
             "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Print number of parameters
    if step == 1:
      params_to_size = opt_state[0]
      if n_devices > 1:
        # Sizes only depend on per-replica shapes; take the first replica
        # (as _print_n_params does) rather than averaging every tensor.
        params_to_size = layers.nested_map(params_to_size, lambda x: x[0])
      sizes = layers.sizes(params_to_size)
      total_size = layers.nested_reduce(sizes, sum)
      step_log(step, "Total trainable parameters size: %d" % total_size)

    # Evaluate in parallel
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fn=functools.partial(jit_model_predict_eval,
                                     params=opt_state[0]),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save computation graph (single-device only for now).
    if save_graphs and step == 1 and n_devices == 1:
      _dump_computation_graphs(output_dir, model_predict_eval, jit_update_fn,
                               step, opt_state, next_train_batch, rng, rngs,
                               save_backward_graph)

    # Save state
    _save_replicated(opt_state, step, history, n_devices, output_dir, False)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fn = lr_fn
    lr_fn = lr_schedule(history)
    if lr_fn != old_lr_fn:  # For performance, only jit if there is a change.
      opt = optimizer(lr_fn)
      jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    # Flush summary writers
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=opt_state, step=step, history=history)