Example #1
0
def save_state(state, output_dir, keep=False):
    """Pickle a training State's params, step and history into output_dir.

    The tuple `(params, state.step, state.history)` is always written to
    `<output_dir>/model.pkl`. When `keep` is True the same tuple is also
    written to `<output_dir>/model_<step>.pkl`, a step-tagged copy that a
    later save to `model.pkl` does not overwrite.

    Args:
      state: object exposing `params`, `step` and `history` attributes.
      output_dir: directory the pickle file(s) are written into.
      keep: if True, additionally write the step-tagged copy.
    """
    # TODO(gilmer, lukaszkaiser): figure out how to use cloudpickle in python3.
    # Currently the code throws an error when run in python3, so cloudpickle is
    # only used on Python 2 and the stdlib pickle everywhere else.
    pkl_module = cloudpickle if sys.version_info[0] < 3 else pickle

    # NOTE(review): the train() shown alongside passes params it already got
    # via trax_opt.get_params(jax.unreplicate(opt_state)) (or params that were
    # never replicated on a single device); confirm that unreplicating again
    # here is intended.
    payload = (jax.unreplicate(state.params), state.step, state.history)

    destinations = [os.path.join(output_dir, "model.pkl")]
    if keep:
        destinations.append(
            os.path.join(output_dir, "model_{}.pkl".format(state.step)))

    for destination in destinations:
        with gfile.GFile(destination, "wb") as f:
            pkl_module.dump(payload, f)

    # Only the last file written is reported (the step-tagged one when keep).
    log("Model saved to %s" % destinations[-1], stdout=False)
Example #2
0
def train(output_dir,
          model=gin.REQUIRED,
          loss_fun=loss,
          inputs=trax_inputs.inputs,
          optimizer=trax_opt.adam,
          lr_schedule=lr.MultifactorSchedule,
          train_steps=1000,
          eval_steps=10,
          eval_frequency=100,
          num_devices=None,
          random_seed=None,
          run_debug_step=False):
    """Train the model on the inputs.

    Args:
      output_dir: Directory where to put the logs and checkpoints.
      model: The model to train as a callable returning 2 callables, an
        init_fun and apply_fun.
      loss_fun: callable with signature: params, trax.inputs.Inputs, model,
        rng -> loss.
      inputs: callable returning trax.inputs.Inputs.
      optimizer: The optimizer as a callable taking a learning_rate callable
        and returning 2 callables, opt_init and opt_update.
      lr_schedule: A learning rate schedule as a function that takes history
        and returns a function from step to learning rate (a float).
      train_steps: int, total number of training steps.
      eval_steps: int, num of steps per evaluation. If None or 0, eval
        disabled.
      eval_frequency: int, how often to run evaluation (every eval_frequency
        steps). If None or 0, training runs as a single epoch of train_steps
        and evaluation/checkpointing happen only once, at its end.
      num_devices: how many devices to use (if None, default, use all
        available).
      random_seed: the random seed to use; time/os dependent if None
        (default).
      run_debug_step: bool, if True, will run the model and loss without @jit
        for one step.

    Returns:
      trax.State
    """
    num_devices = num_devices or jax.lib.xla_bridge.device_count()
    rng = get_random_number_generator_and_set_seed(random_seed)
    gfile.makedirs(output_dir)
    # Create summary writers and history.
    train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "train"))
    eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, "eval"))

    inputs = inputs()

    # Setup optimizer and model. restore_state provides the step, params and
    # history of any earlier run in output_dir; falsy step/params fall back to
    # fresh values below.
    state = restore_state(output_dir)
    history = state.history
    lr_fun = lr_schedule(history)
    opt_init, _ = optimizer(lr_fun)
    model_init, model_predict = model()

    # Setup state
    step = state.step or 0
    rng, init_key = jax_random.split(rng)

    def params_initializer():
        # Only invoked when nothing was restored. NOTE(review): [-1] looks like
        # a placeholder batch dimension and [1] picks params out of
        # model_init's result (stax-style (shape, params)) -- confirm.
        return model_init(init_key, [-1] + list(inputs.input_shape))[1]

    params = state.params or params_initializer()
    opt_state = opt_init(params)
    if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when pmap is stable.
        opt_state = jax.replicate(opt_state)

    # jit model_predict and update so they're fast
    jit_model_predict = backend.jit(model_predict)  # for evaluation
    jit_update_fun = _jit_update_fun(model_predict, loss_fun, optimizer,
                                     lr_fun, num_devices)

    print()
    train_stream = inputs.train_stream()
    # Epoch schedule. Without evaluation there is a single epoch of
    # train_steps. Otherwise the first epoch is a single step (so eval, a
    # checkpoint and the gin config are produced early), the second tops up to
    # eval_frequency, and later epochs are eval_frequency steps each.
    epoch_schedule = [train_steps]
    if eval_frequency:
        # With eval_frequency == 1 the top-up epoch would have 0 steps and
        # just re-evaluate and re-save at step 1, so it is skipped.
        first_epochs = [1] if eval_frequency == 1 else [1, eval_frequency - 1]
        epoch_schedule = itertools.chain(first_epochs,
                                         itertools.repeat(eval_frequency))
    step_log(step, "Starting training using %d devices" % num_devices)

    # Non-compiled debug step helps find problems in models easier.
    # Note that it consumes one batch from train_stream.
    if run_debug_step:
        debug_loss = loss_fun(params, next(train_stream), model_predict, rng)
        step_log(step, "Debug step loss %.8f" % debug_loss)

    # NOTE(review): epochs() receives train_steps but not the restored `step`;
    # if it counts from zero, a resumed run trains train_steps *more* steps
    # rather than up to train_steps in total -- confirm against epochs().
    for epoch, epoch_steps in epochs(train_steps, epoch_schedule):
        # Log separator
        print()

        # Timer
        start_time = time.time()

        for _ in range(epoch_steps):
            # Train
            next_train_batch = next(train_stream)
            if num_devices > 1:  # TODO(lukaszkaiser): use everywhere when possible.
                next_train_batch = reshape_by_device(next_train_batch,
                                                     num_devices)
            rng, subrng = jax_random.split(rng)
            opt_state = jit_update_fun(step, opt_state, next_train_batch,
                                       subrng)
            step += 1

            # LR log (first step, then every 10 steps)
            if step == 1 or step % 10 == 0:
                train_sw.scalar("training/learning rate",
                                lr_fun(step),
                                step=step)

        # Timer
        epoch_time = time.time() - start_time
        step_log(
            step,
            "Ran %d train steps in %0.2f secs" % (epoch_steps, epoch_time))
        # Skip 1-step epochs: their timing is dominated by compilation.
        if epoch_steps > 1:
            train_sw.scalar("training/steps per second",
                            epoch_steps / epoch_time,
                            step=step)

        # Evaluate on a single (unreplicated) copy of the parameters.
        if num_devices > 1:  # TODO(lukaszkaiser): remove branch when possible.
            params = trax_opt.get_params(jax.unreplicate(opt_state))
        else:
            params = trax_opt.get_params(opt_state)
        evaluate_train_and_eval(step=step,
                                inputs=inputs,
                                predict_fun=functools.partial(
                                    jit_model_predict, params),
                                eval_steps=eval_steps,
                                rng=rng,
                                train_sw=train_sw,
                                eval_sw=eval_sw,
                                history=history)

        # Save state
        save_state(State(params=params, step=step, history=history),
                   output_dir)

        # Save Gin config
        # Gin only tracks the used parameters, so we save it after the first epoch.
        if epoch == 1:
            save_gin(output_dir, train_sw)

        # Update learning rate with new history (presumably extended in place
        # by evaluate_train_and_eval above).
        old_lr_fun = lr_fun
        lr_fun = lr_schedule(history)
        # For performance, only re-jit if there is a change.
        # NOTE(review): `!=` on callables is identity-based unless lr_schedule
        # returns an equal/cached object; if it builds a fresh closure every
        # call, this re-jits every epoch -- confirm.
        if lr_fun != old_lr_fun:
            jit_update_fun = _jit_update_fun(model_predict, loss_fun,
                                             optimizer, lr_fun, num_devices)

        # Flush summary writers
        train_sw.writer.flush()
        eval_sw.writer.flush()

    step_log(step, "Training done")
    return State(params=params, step=step, history=history)