Example 1
def _print_n_params(opt_state, n_devices, step):
    """Print out the number of parameters."""
    sizes = layers.sizes(opt_state.params)
    if n_devices > 1:
        unreplicate = lambda x: x[0]
        single_params = layers.nested_map(opt_state.params, unreplicate)
        sizes = layers.sizes(single_params)
    total_size = layers.nested_reduce(sizes, sum)
    step_log(step, "Total trainable parameters size: %d" % total_size)
Example 2
def train(output_dir,
          model=gin.REQUIRED,
          loss_fn=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.SM3,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          save_steps=None,
          eval_steps=10,
          eval_frequency=100,
          n_devices=None,
          random_seed=None,
          run_debug_step=False,
          save_graphs=True,
          save_backward_graph=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train, as a callable returning two callables: an
      init_fn and an apply_fn.
    loss_fn: callable with signature (params, trax.inputs.Inputs, model, rng)
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer (see optimizers/base.py for signature).
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    save_steps: list of integers. Keep a model file at each of the supplied save
      steps.
    eval_steps: int, number of steps per evaluation. If None or 0, evaluation
      is disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, evaluation is disabled.
    n_devices: how many devices to use (if None, the default, use all
      available).
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, will run the model and loss without @jit for
      one step.
    save_graphs: bool, if True, save computation graph to file.
    save_backward_graph: bool, if True, save backward graph to file too.
  Returns:
    trax.State
  """
  if save_steps is None:
    save_steps = []
  device_count = jax.lib.xla_bridge.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count:
    raise ValueError("Jax cannot work yet with n_devices != all devices: "
                     "%d != %d" % (n_devices, device_count))
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs(n_devices)

  # Set up the optimizer and the model.
  state = restore_state(output_dir)
  history = state.history
  lr_fn = lr_schedule(history)
  opt = optimizer(lr_fn)

  model_train = layers.Serial(model(mode="train"))
  model_predict_eval = layers.Serial(model(mode="eval"))

  # Set up the training state.
  step = state.step or 0
  rng, init_rng = jax_random.split(rng)
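  # One RNG per device, so each replica draws an independent random stream.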
  rngs = jax_random.split(rng, n_devices)
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [-1] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        [tuple([-1] + list(shape)) for shape in inputs.input_shape])
  else:  # Otherwise just add [-1] to the input shape.
    model_input_shape = tuple([-1] + list(inputs.input_shape))
  if state.params:
    params = state.params[0]
    opt_state = state.params
  else:
    params = model_train.initialize(model_input_shape, init_rng)
    opt_state = (params, opt.tree_init(params))
  if n_devices > 1:
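    # Replicate params and optimizer slots across devices by adding a leading device axis.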
    replicate = lambda x: numpy.broadcast_to(x, (n_devices,) + x.shape)
    opt_state = layers.nested_map(opt_state, replicate)

  # jit model_predict and update so they're fast
  jit_model_predict_eval = _jit_predict_fn(model_predict_eval, n_devices)
  jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Only training if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % n_devices)

  # A non-compiled debug step makes it easier to find problems in models.
  if run_debug_step:
    debug_loss = loss_fn(params, next(train_stream), model_train, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(train_stream)
      if n_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device(next_train_batch, n_devices)
      opt_state, rngs = jit_update_fn(step, opt_state, next_train_batch, rngs)
      step += 1

      if step in save_steps:
        _save_replicated(opt_state, step, history, n_devices, output_dir, True)

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate",
                        lr_fn(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Print number of parameters
    if step == 1:
      sizes = layers.sizes(opt_state[0])
      if n_devices > 1:
        unreplicate = lambda x: x.mean(0)
        single_params = layers.nested_map(opt_state[0], unreplicate)
        sizes = layers.sizes(single_params)
      total_size = layers.nested_reduce(sizes, sum)
      step_log(step, "Total trainable parameters size: %d" % total_size)

    # Evaluate in parallel
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fn=functools.partial(jit_model_predict_eval,
                                     params=opt_state[0]),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save computation graph (single-device only for now).
    if save_graphs and step == 1 and n_devices == 1:
      params = opt_state[0]
      # Dump computation graphs to files.
      forward_computation = jax.xla_computation(model_predict_eval)(
          next_train_batch[0], params=params, rng=rng)
      with gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f:
        f.write(forward_computation.GetHloText())
      with gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f:
        f.write(forward_computation.GetHloDotGraph())
      backward_computation = jax.xla_computation(jit_update_fn)(
          step, opt_state, next_train_batch, rngs)
      with gfile.GFile(os.path.join(output_dir, "backward.txt"), "w") as f:
        f.write(backward_computation.GetHloText())
      if save_backward_graph:  # Backward graphs can be large so we guard it.
        with gfile.GFile(os.path.join(output_dir, "backward.dot"), "w") as f:
          f.write(backward_computation.GetHloDotGraph())

    # Save state
    _save_replicated(opt_state, step, history, n_devices, output_dir, False)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fn = lr_fn
    lr_fn = lr_schedule(history)
    if lr_fn != old_lr_fn:  # For performance, only jit if there is a change.
      opt = optimizer(lr_fn)
      jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    # Flush summary writers
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=opt_state, step=step, history=history)
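
The epochs helper used in the training loop above is not shown in this excerpt. The sketch below is a plausible reconstruction, assuming it yields (epoch_number, steps_in_epoch) pairs and caps the cumulative number of steps at train_steps; combined with the itertools.chain schedule it evaluates after step 1 and then every eval_frequency steps.

import itertools

def epochs(total_steps, steps_per_epoch):
  """Yield (epoch, n_steps) pairs until total_steps steps have been emitted."""
  done = 0
  for epoch, n_steps in enumerate(steps_per_epoch, start=1):
    n_steps = min(n_steps, total_steps - done)
    if n_steps <= 0:
      break
    done += n_steps
    yield epoch, n_steps

# With train_steps=1000 and eval_frequency=100 the schedule is
# chain([1, 99], repeat(100)), so epochs run for 1, 99, 100, 100, ... steps.
schedule = itertools.chain([1, 100 - 1], itertools.repeat(100))
print(list(epochs(1000, schedule))[:4])
# [(1, 1), (2, 99), (3, 100), (4, 100)]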