Example No. 1
  def predict(params, batch, rng=None):
    """Predict function jited and parallelized as requested."""
    # If not jit'ing, just run the function.
    if not jit_eval:
      return model_predict(params, batch, rng=rng)

    # On one device, jit and run.
    if num_devices == 1:
      return backend.jit(model_predict)(params, batch, rng=rng)

    # Multi-devices, pmap and run.
    @functools.partial(backend.pmap, axis_name="batch")
    def mapped_predict(params, batch, rng):
      return model_predict(params, batch, rng=rng)
    pred = mapped_predict(
        jax.replicate(params),
        reshape_by_device(batch, num_devices),
        jax.replicate(rng))
    batch_size = batch.shape[0]
    return np.reshape(pred, [batch_size] + list(pred.shape[2:]))
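
The snippet above relies on a `reshape_by_device` helper that is not shown. A minimal sketch of what such a helper could look like, assuming the batch is a single array whose leading dimension is divisible by `num_devices` (the actual Trax helper likely also handles nested structures of arrays); the shape convention matches the `np.reshape` that folds the per-device axis back at the end of the snippet:

import numpy as np

def reshape_by_device(batch, num_devices):
  """Reshape [batch, ...] to [num_devices, batch // num_devices, ...]."""
  batch_size = batch.shape[0]
  if batch_size % num_devices != 0:
    raise ValueError("Batch size %d is not divisible by %d devices."
                     % (batch_size, num_devices))
  return np.reshape(
      batch, [num_devices, batch_size // num_devices] + list(batch.shape[1:]))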
Example No. 2
def train(output_dir,
          model=gin.REQUIRED,
          loss_fun=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          num_devices=None,
          random_seed=None,
          run_debug_step=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    loss_fun: callable with signature: params, trax.inputs.Inputs, model, rng
      -> loss.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, number of steps per evaluation. If None or 0, eval is
      disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval is disabled.
    num_devices: int, how many devices to use (if None, the default, use all
      available devices).
    random_seed: the random seed to use; time/os dependent if None (default).
    run_debug_step: bool, if True, run the model and loss without @jit for one
      step.

  Returns:
    trax.State
  """
  num_devices = num_devices or jax.lib.xla_bridge.device_count()
  rng = get_random_number_generator_and_set_seed(random_seed)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs(num_devices)

  # Setup optimizer and model
  state = restore_state(output_dir)
  history = state.history
  lr_fun = lr_schedule(history)
  opt_init, _ = optimizer(lr_fun)
  model_init, model_predict_train = model(mode="train")
  _, model_predict_eval = model(mode="eval")

  # Setup state
  step = state.step or 0
  rng, init_key = jax_random.split(rng)
  params_initializer = \
      lambda: model_init(init_key, [-1] + list(inputs.input_shape))[1]
  params = state.params or params_initializer()
  opt_state = opt_init(params)
  if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when pmap is stable.
    opt_state = jax.replicate(opt_state)

  # jit model_predict and update so they're fast
  jit_model_predict_eval = _jit_predict_fun(model_predict_eval, num_devices)
  jit_update_fun = _jit_update_fun(
      model_predict_train, loss_fun, optimizer, lr_fun, num_devices)

  print()
  train_stream = inputs.train_stream()
  epoch_steps = [train_steps]  # Train-only (no evals) if eval_frequency is 0 or None.
  if eval_frequency and eval_steps > 0:
    epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                   eval_frequency - 1],
                                  itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % num_devices)

  # A non-compiled debug step makes it easier to find problems in models.
  if run_debug_step:
    debug_loss = loss_fun(params, next(train_stream), model_predict_train, rng)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = next(train_stream)
      if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
        next_train_batch = reshape_by_device_pair(next_train_batch, num_devices)
      rng, subrng = jax_random.split(rng)
      opt_state = jit_update_fun(step, opt_state, next_train_batch, subrng)
      step += 1

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate",
                        lr_fun(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Evaluate
    params = trax_opt.get_params(opt_state)
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fun=functools.partial(jit_model_predict_eval, params),
        eval_steps=eval_steps,
        rng=rng,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save state
    save_state(State(params=params, step=step, history=history), output_dir)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fun = lr_fun
    lr_fun = lr_schedule(history)
    if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
      jit_update_fun = _jit_update_fun(
          model_predict_train, loss_fun, optimizer, lr_fun, num_devices)

    # Flush summary writers
    train_sw.flush()
    eval_sw.flush()

  step_log(step, "Training done")
  return State(params=params, step=step, history=history)
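
`lr_schedule` is documented above as a callable that takes the training history and returns a function from step to learning rate. A minimal sketch of a schedule that satisfies this contract, assuming a constant base rate with linear warmup (the function and hyperparameter names here are illustrative and are not those of `lr.MultifactorSchedule`):

def warmup_then_constant_schedule(history, base_lr=0.001, warmup_steps=100):
  """Return a step -> learning-rate function; history is ignored here."""
  del history  # A real schedule could adapt to metrics recorded in history.
  def learning_rate(step):
    warmup_fraction = min(1.0, float(step) / float(max(1, warmup_steps)))
    return base_lr * warmup_fraction
  return learning_rate

With this contract, `lr_fun = lr_schedule(history)` yields a plain Python function, so `lr_fun(step)` can be logged to the summary writer exactly as the training loop above does.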
Example No. 3
 def update(i, opt_state, batch, rng):
   # TODO(lukaszkaiser): investigate how to replicate rng and correct.
   return mapped_update(jax.replicate(i), opt_state, batch, jax.replicate(rng))
Example No. 4
def update(i, opt_state, batch, rng):
    return mapped_update(jax.replicate(i), opt_state, batch, rng)

batches = data_stream()

@partial(pmap, axis_name='batch')
def spmd_update(params, batch):
    grads = grad(loss)(params, batch)
    # We compute the total gradients, summing across the device-mapped axis,
    # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum.
    grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch'))
             for dw, db in grads]
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

# We replicate parameters out across devices. (Check the implementation of
# replicate; analogous to device_put, it's a simple wrapper around pmap.)
params = replicate(init_random_params(param_scale, layer_sizes))

for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
        params = spmd_update(params, next(batches))
    epoch_time = time.time() - start_time

    # We evaluate using the jitted `accuracy` function (not using pmap) by
    # grabbing just one of the replicated parameter values.
    train_acc = accuracy(unreplicate(params), (train_images, train_labels))
    test_acc = accuracy(unreplicate(params), (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
Example No. 6
def train(output_dir,
          model=gin.REQUIRED,
          inputs=gin.REQUIRED,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          num_devices=None,
          run_debug_step=False):
  """Train the model on the inputs.

  Args:
    output_dir: Directory where to put the logs and checkpoints.
    model: The model to train as a callable returning 2 callables, an init_fun
      and apply_fun.
    inputs: callable returning trax.inputs.Inputs.
    optimizer: The optimizer as a callable taking a learning_rate callable and
      returning 2 callables, opt_init and opt_update.
    lr_schedule: A learning rate schedule as a function that takes history and
      returns a function from step to learning rate (a float).
    train_steps: int, total number of training steps.
    eval_steps: int, number of steps per evaluation. If None or 0, eval is
      disabled.
    eval_frequency: int, how often to run evaluation (every eval_frequency
      steps). If None or 0, eval is disabled.
    num_devices: int, how many devices to use (if None, the default, use all
      available devices).
    run_debug_step: bool, if True, run the model and loss without @jit for one
      step.

  Returns:
    trax.State
  """
  num_devices = num_devices or jax.lib.xla_bridge.device_count()
  rng = random.PRNGKey(0)
  gfile.makedirs(output_dir)
  # Create summary writers and history.
  train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
  eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

  inputs = inputs()

  # Setup optimizer and model
  state = restore_state(output_dir)
  history = state.history
  lr_fun = lr_schedule(history)
  opt_init, _ = optimizer(lr_fun)
  model_init, model_predict_original = model()
  # We need a model_predict that fills in the random generator if needed.
  def model_predict(x, y, **kwargs):
    """Same as model_predict_original but fill in rng if it isn't passed."""
    if "rng" in kwargs:
      return model_predict_original(x, y, **kwargs)
    return model_predict_original(x, y, rng=rng, **kwargs)

  # Setup state
  step = state.step or 0
  params_initializer = lambda: model_init([-1] + list(inputs.input_shape))[1]
  params = state.params or params_initializer()
  opt_state = jax.replicate(opt_init(params))

  # jit model_predict and update so they're fast
  jit_model_predict = jax.jit(model_predict)  # for evaluation
  jit_update_fun = _jit_update_fun(model_predict, loss, optimizer, lr_fun)

  print()
  train_stream = inputs.train_stream()
  epoch_steps = itertools.chain([1,  # first epoch only 1 step
                                 eval_frequency - 1],
                                itertools.repeat(eval_frequency))
  step_log(step, "Starting training using %d devices" % num_devices)

  # A non-compiled debug step makes it easier to find problems in models.
  if run_debug_step:
    debug_loss = loss(params, next(train_stream), model_predict)
    step_log(step, "Debug step loss %.8f" % debug_loss)

  for epoch, epoch_steps in epochs(train_steps, epoch_steps):
    # Log separator
    print()

    # Timer
    start_time = time.time()

    for _ in range(epoch_steps):
      # Train
      next_train_batch = reshape_by_device(next(train_stream), num_devices)
      opt_state = jit_update_fun(step, opt_state, next_train_batch)
      step += 1

      # LR log
      if step == 1 or step % 10 == 0:
        train_sw.scalar("training/learning rate",
                        lr_fun(step), step=step)

    # Timer
    epoch_time = time.time() - start_time
    step_log(step, "Ran %d train steps in %0.2f secs" %
             (epoch_steps, epoch_time))
    if epoch_steps > 1:
      train_sw.scalar("training/steps per second",
                      epoch_steps / epoch_time, step=step)

    # Evaluate
    params = jax_opt.get_params(jax.unreplicate(opt_state))
    evaluate_train_and_eval(
        step=step,
        inputs=inputs,
        predict_fun=functools.partial(jit_model_predict, params),
        eval_steps=eval_steps,
        train_sw=train_sw,
        eval_sw=eval_sw,
        history=history)

    # Save state
    save_state(State(params=params, step=step, history=history), output_dir)

    # Save Gin config
    # Gin only tracks the used parameters, so we save it after the first epoch.
    if epoch == 1:
      save_gin(output_dir, train_sw)

    # Update learning rate with new history
    old_lr_fun = lr_fun
    lr_fun = lr_schedule(history)
    if lr_fun != old_lr_fun:  # For performance, only jit if there is a change.
      jit_update_fun = _jit_update_fun(model_predict, loss, optimizer, lr_fun)

    # Flush summary writers
    train_sw.writer.flush()
    eval_sw.writer.flush()

  step_log(step, "Training done")
  return State(params=params, step=step, history=history)
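
Both training loops iterate over `epochs(train_steps, epoch_steps)` without showing that generator. A minimal sketch of one possible implementation, assuming it should yield `(epoch_number, steps_in_epoch)` pairs and never exceed `train_steps` in total:

def epochs(total_steps, steps_per_epoch):
  """Yield (epoch_number, steps_in_epoch) pairs, capped at total_steps."""
  steps_done = 0
  for epoch, steps in enumerate(steps_per_epoch, start=1):
    steps = min(steps, total_steps - steps_done)
    if steps <= 0:
      return
    yield epoch, steps
    steps_done += steps

Note that `steps_per_epoch` may be a finite list (`[train_steps]`) or the infinite `itertools.chain(...)` iterator built above; the cap at `total_steps` is what terminates the loop in the infinite case.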
Example No. 7
 def update(i, opt_state, batch, rng):
   # TODO(lukaszkaiser): investigate how to replicate rng and correct.
   # `mapped_update` is assumed to already be wrapped with
   # `backend.pmap(..., axis_name="batch")`; pmap must wrap the function
   # itself rather than its result, so here we only replicate the scalar
   # step counter and the rng before calling it.
   return mapped_update(jax.replicate(i), opt_state, batch,
                        jax.replicate(rng))
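
For context, `mapped_update` in the fragments above is the pmap-wrapped inner update step. A hypothetical sketch of how it could be defined, reusing `backend.pmap`, `trax_opt.get_params`, `loss_fun`, `model_predict_train` and the `opt_init`/`opt_update` pair documented in Example No. 2; the exact body in the original code may differ:

 @functools.partial(backend.pmap, axis_name="batch")
 def mapped_update(i, opt_state, batch, rng):
   # Per-device step: differentiate loss_fun w.r.t. the params stored in
   # opt_state, sum the gradients across the "batch" device axis, and apply
   # one optimizer update.
   params = trax_opt.get_params(opt_state)
   grads = jax.grad(loss_fun)(params, batch, model_predict_train, rng)
   grads = jax.tree_util.tree_map(lambda g: jax.lax.psum(g, "batch"), grads)
   return opt_update(i, grads, opt_state)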