def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          save_steps=None,
          eval_steps=10,
          eval_frequency=100,
          n_devices=None,
          random_seed=None,
          save_graphs=True,
          save_backward_graph=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    save_steps: list of integers. Keep a model file at each of the supplied
      save steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    n_devices: how many devices to use (if None, default, use all available).
    random_seed: the random seed to use; time/os dependent if None (default).
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.

  Returns:
    trax.State
  """
  trainer = Trainer(model, loss_fn, optimizer, lr_schedule, inputs, output_dir,
                    random_seed=random_seed, n_devices=n_devices,
                    save_steps=save_steps)

  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(trainer.step,
           "Starting training using %d devices" % trainer.n_devices)

  for _, epoch_steps in epochs(train_steps, epoch_steps):
    trainer.train_epoch(epoch_steps, eval_steps)

    # Update learning rate with new history.
    trainer.update_learning_rate()

    # Bookkeeping we do at the first step.
    if trainer.step == 1:
      # Print number of parameters.
      trainer.print_n_params()

      # Save computation graph (single-device only for now).
      if save_graphs and backend.get_name() == "jax":
        trainer.save_computation_graphs(save_backward_graph)

      # Save Gin config.
      trainer.save_gin()

  step_log(trainer.step, "Training done")
  return trainer.state
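# Example usage (illustrative sketch only, not part of this module):
# `my_model_fn` is a hypothetical model constructor; any callable accepting a
# `mode` keyword and returning the usual (init_fn, apply_fn) pair would do.
# The output directory and step counts below are arbitrary.
def _example_train_usage(my_model_fn):
  """Runs a short illustrative training job and returns the final State."""
  state = train(
      output_dir="/tmp/example_run",  # logs and checkpoints land here
      model=my_model_fn,
      train_steps=200,                # short run, just for illustration
      eval_steps=10,
      eval_frequency=100,             # evaluate after step 1, then every 100
      save_steps=[100, 200],          # keep checkpoint files at these steps
  )
  # The returned trax.State carries the final step and training history.
  return state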
def __call__(self, x, params=(), state=(), **kwargs):
  try:
    # If params are nothing, we may be reusing this layer.
    # Use the cached parameters to calculate the value.
    # Note: to make sure jit tracers can decide this branch in python we
    # use "params is ()" instead of, e.g., "not params" or "params == ()".
    if params is ():  # pylint: disable=literal-comparison
      params = self._params
    else:
      # In this case, we're called for the first time: cache parameters.
      self._params = params

    if not self.has_custom_grad:
      return self.call(x, params=params, state=state, **kwargs)

    # Custom gradients part.
    assert backend.get_name() == 'jax', (
        'Custom gradients are only supported in JAX for now.')

    # TODO(wangpeng): JAX doesn't support custom grads for functions with
    #   auxiliary output yet (https://github.com/google/jax/issues/844).
    #   Will remove the constraints on state below when this feature is
    #   added to JAX.
    assert not jax.tree_util.tree_leaves(state), (
        'Custom gradients require trivial start state. Got %s' % str(state))

    def check_end_state(output_state):
      output, state = output_state
      assert not jax.tree_util.tree_leaves(state), (
          'Custom gradients require trivial end state. Got %s' % str(state))
      return output

    # See this link for how custom transformations are defined in JAX:
    # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
    # Note that we capture the kwargs and don't calculate gradients wrt. them.
    @jax.custom_transforms
    def do_call(y, params):
      return check_end_state(
          self.call(y, params=params, state=state, **kwargs))

    # This is the custom gradient (vector-jacobian product in JAX) function.
    # For the exact specification of this custom transformation see this link:
    # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
    def do_call_vjp(y, params):
      output = check_end_state(
          self.call(y, params=params, state=state, **kwargs))
      def vjpfun(grad):
        return self.custom_grad(y, output, grad, params, state, **kwargs)
      return output, vjpfun

    jax.defvjp_all(do_call, do_call_vjp)
    return do_call(x, params), state

  except Exception:
    name, trace = self.__class__.__name__, _short_traceback()
    raise LayerError(name, 'call', self._caller, shapes(x), trace)
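# Illustrative sketch (not from this module): a layer that opts into the
# custom-gradient branch of __call__ above.  It assumes a `Layer` base class
# whose `call` returns an (output, state) pair and whose `has_custom_grad` is
# read as an attribute/property; the signature of `custom_grad` mirrors how
# __call__ invokes it.  The straight-through behaviour here is only an example.
class StraightThroughExample(Layer):
  """Identity forward pass that passes incoming gradients straight through."""

  @property
  def has_custom_grad(self):
    return True  # route __call__ through the jax.defvjp_all path

  def call(self, x, params=(), state=(), **kwargs):
    # Identity forward pass; state must stay trivial (see the asserts above).
    return x, state

  def custom_grad(self, inputs, output, grad, params, state, **kwargs):
    # Return cotangents for (inputs, params): pass the incoming gradient
    # through unchanged, with an empty cotangent for the empty params.
    return grad, ()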
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          save_steps=None,
          eval_steps=10,
          eval_frequency=100,
          n_devices=None,
          random_seed=None,
          run_debug_step=False,
          save_graphs=True,
          save_backward_graph=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fn
      and apply_fn.
    loss_fn: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    save_steps: list of integers. Keep a model file at each of the supplied
      save steps.
    eval_steps: int, num of steps per evaluation. If None or 0, eval disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval disabled.
    n_devices: how many devices to use (if None, default, use all available).
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit
      for one step.
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.

  Returns:
    trax.State
  """
  if save_steps is None:
    save_steps = []
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs(n_devices)

  # Setup optimizer and model.
  state = restore_state(output_dir)
  history = state.history
  lr_fn = lr_schedule(history)
  opt = optimizer(lr_fn)

  model_train = model(mode="train")
  model_predict_eval = model(mode="eval")

  # Setup state.
  step = state.step or 0
  rng, init_rng = jax_random.split(rng)
  rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
  # Change all None to 1 in input shape.
  model_input_shape = layers.nested_map(
      model_input_shape, lambda x: x if x else 1)

  if state.params:
    params = state.params[0]
    opt_state = state.params
  else:
    params = model_train.initialize(
        model_input_shape, inputs.input_dtype, init_rng)
    opt_state = (params, opt.tree_init(params))
  if n_devices > 1:
    replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
    opt_state = layers.nested_map(opt_state, replicate)

  # jit model_predict and update so they're fast.
  jit_model_predict_eval = _jit_predict_fn(model_predict_eval, n_devices)
  jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % n_devices)

  # Non-compiled debug step helps find problems in models easier.
  if run_debug_step:
    debug_loss = loss_fn(params, next(train_stream), model_train, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator.
    print()

    # Timer.
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train.
      next_train_batch = next(train_stream)
      if n_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device(next_train_batch, n_devices)
      opt_state, rngs = jit_update_fn(step, opt_state, next_train_batch, rngs)
      step += 1

      if step in save_steps:
        _save_replicated(opt_state, step, history, n_devices, output_dir, True)

      # LR log.
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate", lr_fn(step), step=step)

    # Timer.
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Print number of parameters.
    if step == 1:
      sizes = layers.sizes(opt_state[0])
      if n_devices > 1:
        unreplicate = lambda x: x.mean(0)
        single_params = layers.nested_map(opt_state[0], unreplicate)
        sizes = layers.sizes(single_params)
      total_size = layers.nested_reduce(sizes, sum)
      step_log(step, "Total trainable parameters size: %d" % total_size)

    # Evaluate in parallel.
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fn=functools.partial(jit_model_predict_eval,
                                     params=opt_state[0]),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save computation graph (single-device only for now).
    if (save_graphs and backend.get_name() == "jax" and step == 1
        and n_devices == 1):
      params = opt_state[0]
      # Dump computation graphs to files.
      forward_computation = jax.xla_computation(model_predict_eval)(
          next_train_batch[0], params=params, rng=rng)
      with gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f:
        f.write(forward_computation.GetHloText())
      with gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f:
        f.write(forward_computation.GetHloDotGraph())
      backward_computation = jax.xla_computation(jit_update_fn)(
          step, opt_state, next_train_batch, rngs)
      with gfile.GFile(os.path.join(output_dir, "backward.txt"), "w") as f:
        f.write(backward_computation.GetHloText())
      if save_backward_graph:  # Backward graphs can be large so we guard it.
        with gfile.GFile(os.path.join(output_dir, "backward.dot"), "w") as f:
          f.write(backward_computation.GetHloDotGraph())

    # Save state.
    _save_replicated(opt_state, step, history, n_devices, output_dir, False)

    # Save Gin config.
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history.
    old_lr_fn = lr_fn
    lr_fn = lr_schedule(history)
    if lr_fn != old_lr_fn:  # For performance, only jit if there is a change.
      opt = optimizer(lr_fn)
      jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    # Flush summary writers.
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=opt_state, step=step, history=history)
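# Illustrative sketch (not the module's actual helper): the `epochs` generator
# used by both train loops is assumed to yield (epoch_index, steps_in_epoch)
# pairs while capping the total at `train_steps`.  This hypothetical
# reimplementation only shows the intended schedule: with eval_frequency=100
# the epochs contain 1, 99, 100, 100, ... steps, so the first evaluation runs
# after a single step and then every 100 steps.
def _example_epochs(total_steps, steps_per_epoch):
  """Yields (epoch_number, steps_this_epoch), stopping at total_steps."""
  done = 0
  for epoch, steps in enumerate(steps_per_epoch, start=1):
    steps = min(steps, total_steps - done)  # never run past total_steps
    if steps <= 0:
      return
    done += steps
    yield epoch, steps

# For example (assuming the schedule above with eval_frequency=100):
#   list(_example_epochs(250, itertools.chain([1, 99], itertools.repeat(100))))
#   == [(1, 1), (2, 99), (3, 100), (4, 50)]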