Example #1
def train(experiment_class, config, checkpointer, writer, periodic_actions=()):
    """Main training loop."""
    logging.info("Training with config:\n%s", config)
    is_chief = jax.host_id() == 0

    rng = jax.random.PRNGKey(config.random_seed)
    with utils.log_activity("experiment init"):
        experiment = _initialize_experiment(experiment_class, "train", rng,
                                            config.experiment_kwargs)

    state = checkpointer.get_experiment_state("latest")
    state.global_step = 0
    state.experiment_module = experiment
    state.train_step_rng = jnp.broadcast_to(rng, (jax.local_device_count(), ) +
                                            rng.shape)

    if checkpointer.can_be_restored("latest"):
        with utils.log_activity("checkpoint restore"):
            checkpointer.restore("latest")

    periodic_actions += (utils.PeriodicAction(
        _log_outputs,
        interval_type=config.interval_type,
        interval=config.log_tensors_interval), )

    if config.train_checkpoint_all_hosts or is_chief:
        if config.save_checkpoint_interval > 0:
            periodic_actions += (utils.PeriodicAction(
                lambda *_: checkpointer.save("latest"),
                interval_type=config.interval_type,
                interval=config.save_checkpoint_interval,
                run_async=False), )  # run_async=True would not be thread-safe.

    if is_chief:
        if writer is not None:

            def write_scalars(global_step: int, scalar_values):
                writer.write_scalars(global_step, scalar_values)

            periodic_actions += (utils.PeriodicAction(
                write_scalars,
                interval_type=config.interval_type,
                interval=config.log_train_data_interval,
                log_all_data=config.log_all_train_data), )

    for pa in periodic_actions:
        pa.update_time(time.time(), state.global_step)

    experiment.train_loop(config, state, periodic_actions, writer)

    if is_chief:
        with utils.log_activity("final checkpoint"):
            checkpointer.save("latest")

    # We occasionally see errors when the final checkpoint is being written if
    # the other hosts exit. Here we force all hosts to participate in one final
    # collective so the non-chief hosts cannot exit before the chief writes out
    # the final checkpoint.
    utils.rendezvous()
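
A minimal usage sketch of passing an extra periodic action into this entry point; `my_debug_action`, `MyExperiment`, and the interval of 100 are illustrative assumptions, not part of the code above. The callback signature mirrors the `write_scalars` wrapper defined inside `train`.

# Hedged sketch: supplying an additional PeriodicAction to train().
# `my_debug_action`, `MyExperiment` and the interval value are assumptions.
def my_debug_action(global_step, scalar_values):
    logging.info("custom action at step %d, scalars: %s",
                 global_step, scalar_values)

extra_actions = (utils.PeriodicAction(
    my_debug_action,
    interval_type=config.interval_type,  # same interval type as the built-ins
    interval=100), )

train(MyExperiment, config, checkpointer, writer,
      periodic_actions=extra_actions)
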
Example #2
    def test_log_success(self, mock_info):
        """Tests that logging an activity is successful."""

        with utils.log_activity("for test"):
            pass

        mock_info.assert_any_call("[jaxline] %s starting...", "for test")
        mock_info.assert_any_call("[jaxline] %s finished.", "for test")
Example #3
  def train_loop(
      self,
      config: config_dict.ConfigDict,
      state,
      periodic_actions: List[utils.PeriodicAction],
      writer: Optional[utils.Writer] = None,
  ) -> None:
    """Default training loop implementation.

    Can be overridden for advanced use cases that need different training-loop
    logic, e.g. an on-device training loop with jax.lax.while_loop, or to add
    custom periodic actions.

    Args:
      config: The config of the experiment that is being run.
      state: Checkpointed state of the experiment.
      periodic_actions: List of actions that should be called after every
        training step, for checkpointing and logging.
      writer: An optional writer to pass to the experiment step function.
    """

    @functools.partial(jax.pmap, axis_name="i")
    def next_device_state(
        global_step: jnp.ndarray,
        rng: jnp.ndarray,
        host_id: Optional[jnp.ndarray],
    ):
      """Updates device global step and rng in one pmap fn to reduce overhead."""
      global_step += 1
      step_rng, state_rng = tuple(jax.random.split(rng))
      step_rng = utils.specialize_rng_host_device(
          step_rng, host_id, axis_name="i", mode=config.random_mode_train)
      return global_step, (step_rng, state_rng)

    global_step_devices = np.broadcast_to(state.global_step,
                                          [jax.local_device_count()])
    host_id_devices = utils.host_id_devices_for_rng(config.random_mode_train)
    step_key = state.train_step_rng

    with utils.log_activity("training loop"):
      while self.should_run_step(state.global_step, config):
        with jax.profiler.StepTraceAnnotation(
            "train", step_num=state.global_step):
          scalar_outputs = self.step(
              global_step=global_step_devices, rng=step_key, writer=writer)

          t = time.time()
          # Update state's (scalar) global step (for checkpointing).
          # global_step_devices will be back in sync with this after the call
          # to next_device_state below.
          state.global_step += 1
          global_step_devices, (step_key, state.train_step_rng) = (
              next_device_state(global_step_devices,
                                state.train_step_rng,
                                host_id_devices))

        for action in periodic_actions:
          action(t, state.global_step, scalar_outputs)
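
Both this loop and the evaluator call self.should_run_step to decide when to stop; a minimal sketch of that predicate (the training_steps config field is an assumption) could be:

  # Hedged sketch of the stopping predicate used by the loop above.
  # `config.training_steps` is an assumed config field.
  def should_run_step(self, global_step: int,
                      config: config_dict.ConfigDict) -> bool:
    """Returns True while there are training steps left to run."""
    return global_step < config.training_steps
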
Example #4
    def test_log_failure(self, mock_info, mock_exc):
        """Tests that an error thrown by an activity is correctly caught."""

        with self.assertRaisesRegex(ValueError, "Intentional"):
            with utils.log_activity("for test"):
                raise ValueError("Intentional")

        mock_info.assert_any_call("[jaxline] %s starting...", "for test")
        mock_exc.assert_any_call("[jaxline] %s failed with error.", "for test")
Example #5
def evaluate(
    experiment_class,
    config,
    checkpointer: utils.Checkpointer,
    writer: Optional[utils.Writer],
    jaxline_mode: Optional[str] = None,
):
    """Main evaluation loop."""
    if jaxline_mode is None:
        jaxline_mode = FLAGS.jaxline_mode

    logging.info("Evaluating with config:\n%s", config)

    global_step = 0
    eval_rng = jax.random.PRNGKey(config.random_seed)
    experiment = _initialize_experiment(experiment_class, jaxline_mode,
                                        eval_rng, config.experiment_kwargs)

    if config.best_model_eval_metric and jax.host_id() == 0:
        # Initialize best state.
        best_state = checkpointer.get_experiment_state("best")
        if config.best_model_eval_metric_higher_is_better:
            best_state.best_eval_metric_value = float("-inf")
            eval_metric_is_better_op = jnp.greater
            eval_metric_comparison_str = ">"
        else:
            best_state.best_eval_metric_value = float("inf")
            eval_metric_is_better_op = jnp.less
            eval_metric_comparison_str = "<"
        best_state.best_model_eval_metric = config.best_model_eval_metric

        best_state.experiment_module = experiment

        # Restore to preserve 'best_eval_metric_value' if evaluator was preempted.
        if checkpointer.can_be_restored("best"):
            with utils.log_activity("best checkpoint restore"):
                checkpointer.restore("best")

    # Will evaluate the latest checkpoint in the directory.
    state = checkpointer.get_experiment_state("latest")
    state.global_step = global_step
    state.experiment_module = experiment
    state.train_step_rng = None

    eval_rng = jnp.broadcast_to(eval_rng,
                                (jax.local_device_count(), ) + eval_rng.shape)
    host_id_devices = utils.host_id_devices_for_rng(config.random_mode_eval)
    eval_rng = jax.pmap(functools.partial(utils.specialize_rng_host_device,
                                          axis_name="i",
                                          mode=config.random_mode_eval),
                        axis_name="i")(eval_rng, host_id_devices)

    if config.one_off_evaluate:
        checkpointer.restore("latest")
        global_step_devices = utils.bcast_local_devices(
            jnp.asarray(state.global_step))
        scalar_values = utils.evaluate_should_return_dict(experiment.evaluate)(
            global_step=global_step_devices, rng=eval_rng, writer=writer)
        if writer is not None:
            writer.write_scalars(state.global_step, scalar_values)
        logging.info("Evaluated specific checkpoint, exiting.")
        return

    old_checkpoint_path = None
    initial_weights_are_evaluated = False

    while True:
        checkpoint_path = checkpointer.restore_path("latest")

        if (checkpoint_path is None and config.eval_initial_weights
                and not initial_weights_are_evaluated):
            # Skip restoring a checkpoint and directly call evaluate if
            # `config.eval_initial_weights` is set, but don't do it more than
            # once.
            initial_weights_are_evaluated = True
        else:
            if checkpoint_path in (None, old_checkpoint_path):
                logging.info(
                    "Checkpoint %s invalid or already evaluated, waiting.",
                    checkpoint_path)
                time.sleep(10)
                continue

            checkpointer.restore("latest")

        global_step_devices = utils.bcast_local_devices(
            jnp.asarray(state.global_step))
        scalar_values = utils.evaluate_should_return_dict(experiment.evaluate)(
            global_step=global_step_devices, rng=eval_rng, writer=writer)
        if writer is not None:
            writer.write_scalars(state.global_step, scalar_values)
        old_checkpoint_path = checkpoint_path
        # Decide whether to save a "best checkpoint".
        if config.best_model_eval_metric and jax.host_id() == 0:
            if config.best_model_eval_metric not in scalar_values:
                raise ValueError(
                    f"config.best_model_eval_metric has been specified "
                    f"as {config.best_model_eval_metric}, but this key "
                    f"was not returned by the evaluate method. Got: "
                    f"{scalar_values.keys()}")
            current_eval_metric_value = scalar_values[
                config.best_model_eval_metric]
            old_eval_metric_value = best_state.best_eval_metric_value
            if eval_metric_is_better_op(current_eval_metric_value,
                                        old_eval_metric_value):
                logging.info("%s: %s %s %s, saving new best checkpoint.",
                             config.best_model_eval_metric,
                             current_eval_metric_value,
                             eval_metric_comparison_str, old_eval_metric_value)
                best_state.global_step = state.global_step
                best_state.experiment_module = experiment
                best_state.best_eval_metric_value = current_eval_metric_value
                best_state.train_step_rng = state.train_step_rng
                checkpointer.save("best")

        if not experiment.should_run_step(state.global_step, config):
            logging.info("Last checkpoint (iteration %d) evaluated, exiting.",
                         state.global_step)
            break
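
For reference, the configuration fields this loop reads, written out as a sketch; the field names come from the code above, while the concrete values and the metric name are assumptions.

# Illustrative evaluation settings; values and metric name are assumptions.
from ml_collections import config_dict

eval_config = config_dict.ConfigDict()
eval_config.random_seed = 42
eval_config.random_mode_eval = "same_host_same_device"  # rng mode (assumed value)
eval_config.one_off_evaluate = False     # keep polling for new checkpoints
eval_config.eval_initial_weights = True  # evaluate once before any checkpoint exists
eval_config.best_model_eval_metric = "top_1_acc"  # assumed metric name
eval_config.best_model_eval_metric_higher_is_better = True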