Example #1
    def checkpoint(self, train_state, step):
        """Saves checkpoint.

        Syncs the model state across replicas if needed.

        Args:
          train_state: TrainState; A flax struct that keeps the model state and
            optimizer state.
          step: int; Number of steps completed so far during training.

        Returns:
          train_state
        """
        checkpoint_flag = False
        if self.hparams.get('ckpnt_steps', None) and self.hparams.checkpoint:
            if step in self.hparams.get('ckpnt_steps'):
                checkpoint_flag = True
        elif ((step % self.checkpoint_frequency == 0) or
              (step == self.total_steps)) and self.hparams.checkpoint:
            checkpoint_flag = True

        if checkpoint_flag:
            # Sync model state across replicas.
            train_state = pipeline_utils.sync_model_state_across_replicas(
                train_state)
            if jax.host_id() == 0:
                pipeline_utils.save_checkpoint(self.experiment_dir,
                                               train_state,
                                               keep=self.hparams.keep_ckpts)

        return train_state
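
For reference, here is a minimal sketch of what pipeline_utils.sync_model_state_across_replicas could look like, assuming train_state is a replicated flax struct whose model_state holds per-replica batch statistics (e.g. BatchNorm running averages) and exposes .replace(); the project's actual helper may differ:

import jax

# Hypothetical sketch: average the replicated model state over the device axis
# so every replica holds identical statistics before checkpointing/evaluation.
pmap_mean = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), axis_name='x')

def sync_model_state_across_replicas(train_state):
    synced_state = pmap_mean(train_state.model_state)
    return train_state.replace(model_state=synced_state)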
Example #2
    def maybe_eval_and_log(self, eval_summary, master, step, tick,
                           train_metrics, train_summary):
        """Maybe evaluate and log based on the current step value."""
        if (step % self.eval_frequency == 0) or (step == self.total_steps):
            del eval_summary
            del train_summary

            train_metrics = common_utils.get_metrics(train_metrics)
            train_summary = pipeline_utils.compute_global_mean_metrics(
                train_metrics)

            tock = time.time()
            steps_per_sec = self.eval_frequency / (tock - tick)
            tick = tock

            # log train summary
            if master:
                self.write_train_summary(step=step,
                                         metric_dict=train_metrics,
                                         summary=train_summary,
                                         steps_per_sec=steps_per_sec)
            # reset metric accumulation for next evaluation cycle
            del train_metrics
            train_metrics = []

            # sync model state across replicas
            self.train_state = pipeline_utils.sync_model_state_across_replicas(
                self.train_state)

            # evaluate and log the results
            eval_summary, _ = self.eval(step, self.train_state)
        return eval_summary, train_metrics, train_summary, tick
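
The variant below inlines the metric gathering that common_utils.get_metrics performs in the version above: fetch the accumulated per-step metric dicts from the devices and stack them per key. A rough, hedged equivalent:

import jax
from flax.training import common_utils

def get_metrics_sketch(device_metrics):
    # Not the exact flax implementation: copy metrics to host memory and
    # stack the list of per-step dicts into a single dict of arrays.
    host_metrics = jax.device_get(device_metrics)
    return common_utils.stack_forest(host_metrics)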
    def maybe_eval_and_log(self, eval_env_ids, eval_summary, master, step,
                           tick, train_metrics, train_summary):
        """Maybe evaluate on the given eval environments and log, based on step."""
        if (step % self.eval_frequency == 0) or (step == self.total_steps):
            train_metrics = jax.device_get(train_metrics)
            train_metrics = common_utils.stack_forest(train_metrics)
            train_summary = pipeline_utils.compute_global_mean_metrics(
                train_metrics)
            tock = time.time()
            steps_per_sec = self.eval_frequency / (tock - tick)
            tick = tock

            # Log train summary:
            if master:
                self.write_train_summary(step=step,
                                         metric_dict=train_metrics,
                                         summary=train_summary,
                                         steps_per_sec=steps_per_sec)

            # Reset metric accumulation for next evaluation cycle:
            train_metrics = []

            # Sync model state across replicas:
            self.train_state = pipeline_utils.sync_model_state_across_replicas(
                self.train_state)

            # Evaluate and log the results:
            eval_summary, self.train_state = self.eval(step, self.train_state,
                                                       eval_env_ids)
        return eval_summary, train_metrics, train_summary, tick
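
Both variants then reduce the stacked metrics to scalars with pipeline_utils.compute_global_mean_metrics. A plausible sketch, assuming each metric is an array of per-step, per-device values (the real helper may instead weight by example counts):

import jax
import numpy as np

def compute_global_mean_metrics(train_metrics):
    # Hypothetical sketch: collapse each stacked metric to a single scalar mean.
    return jax.tree_util.tree_map(lambda x: float(np.mean(x)), train_metrics)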
    def train(self):
        """Training loop."""

        master = jax.host_id() == 0
        train_metrics = []
        train_summary, eval_summary = None, None
        tick = time.time()

        @jax.pmap
        def get_reps(train_state, flax_module, batch):
            with nn.stochastic(train_state.rng):
                with nn.stateful(train_state.model_state):
                    _, reps, _ = flax_module(batch['inputs'],
                                             train=True,
                                             return_activations=True)

            return reps

        # Prepare arguments for layer sampling:
        sample_batch = self.get_next_batch(self.task.dataset.data_iters.train)
        reps = get_reps(self.train_state, self.train_state.optimizer.target,
                        sample_batch)
        layer_keys, mixup_layers = pipeline_utils.get_sample_layer_params(
            self.hparams, reps)

        # Train loop:
        for step in range(self.start_step + 1, self.total_steps + 1):
            train_batch = self.get_next_batch(
                self.task.dataset.data_iters.train)
            sampled_layer = pipeline_utils.sample_layer(
                layer_keys, mixup_layers=mixup_layers)

            self.train_state, t_metrics = self.pmapped_train_step(
                self.train_state, train_batch, sampled_layer)
            train_metrics.append(t_metrics)

            eval_summary, train_metrics, train_summary, tick = self.maybe_eval_and_log(
                eval_summary, master, step, tick, train_metrics, train_summary)

            # sync and save
            self.checkpoint(self.train_state, step)

        if master:
            # Sync model state across replicas, then evaluate on the test
            # split and log the results.
            self.train_state = pipeline_utils.sync_model_state_across_replicas(
                self.train_state)
            eval_summary, self.train_state = self.eval(step, self.train_state,
                                                       'test')

        # wait until computations are done before exiting (for timing!)
        jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

        # return the train and eval summary after last step for regression testing
        return train_summary, eval_summary
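
The loop above calls self.pmapped_train_step with the sampled layer as an extra argument. A minimal sketch of how such a step could be wired up with jax.pmap, using a stub train_step (the real per-device update mixes activations at the sampled layer and computes the mixup loss and gradients):

import jax
import jax.numpy as jnp

def train_step(train_state, batch, sampled_layer):
    # Stub per-device update; a real step would use sampled_layer to mix
    # activations, compute gradients, and update the optimizer state.
    loss = jnp.zeros(())
    metrics = {'loss': jax.lax.pmean(loss, 'batch')}
    return train_state, metrics

# If sampled_layer is a plain Python value (e.g. a layer name), it has to be
# passed as a static argument rather than a device array.
pmapped_train_step = jax.pmap(
    train_step, axis_name='batch', static_broadcasted_argnums=(2,))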
Example #5
  def eval(self, step, train_state, split_name='validation'):
    """Evaluation loop.

    Args:
      step: int; Training step.
      train_state: TrainState; Object containing training state.
      split_name: str; Name of the dataset split to evaluate on.

    Returns:
      eval_summary, train_state
    """
    train_state = pipeline_utils.sync_model_state_across_replicas(train_state)
    eval_summary, eval_metrics = self.eval_split(
        train_state=train_state, split_name=split_name)

    # log eval summary
    master = jax.host_id() == 0

    if master:
      self.write_eval_summary(
          step=step, metric_dict=eval_metrics, summary=eval_summary)
    return eval_summary, train_state
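
write_eval_summary itself is not shown in these examples. A hedged sketch of what it might do, assuming the trainer owns a flax.metrics.tensorboard.SummaryWriter (the real method may log additional per-step metrics or write to disk differently):

from flax.metrics import tensorboard

def write_eval_summary(summary_writer, step, metric_dict, summary):
    # Hypothetical sketch: record each global eval metric as a scalar.
    del metric_dict  # per-batch metrics could also be logged; omitted here
    for key, value in summary.items():
        summary_writer.scalar('eval/' + key, value, step)
    summary_writer.flush()

# Example usage with a writer pointing at an arbitrary log directory:
# writer = tensorboard.SummaryWriter('/tmp/eval_logs')
# write_eval_summary(writer, step=100, metric_dict={}, summary={'accuracy': 0.9})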
Example #6
    def _train_loop(self, environments, start_step, end_step, master):
        """Training loop.

        Trains the model on the given environment set for (end_step - start_step)
        steps.

        Args:
          environments: dict; A dictionary from environment name to environment
            data iterator.
          start_step: int; Starting step in the loop.
          end_step: int; End step in the loop.
          master: bool; Whether this is the host process. If so, log and checkpoint.

        Returns:
          Evaluation summaries and metrics.
        """
        # Initialize return values.
        train_metrics = []
        train_summary, eval_summary = None, None
        tick = time.time()

        eval_env_ids = list(
            map(int, self.task.dataset.data_iters.validation.keys()))
        train_env_ids, train_iters = list(zip(*dict(environments).items()))
        train_env_ids = list(map(int, train_env_ids))

        for step in range(start_step + 1, end_step + 1):

            # Get next batch.
            train_batch = self.get_next_batch(train_iters)

            # Run train step and get the metrics and the new train state.
            self.train_state, t_metrics = self.pmapped_train_step(
                self.train_state, train_batch, train_env_ids)
            train_metrics.append(t_metrics)

            if (step % self.eval_frequency == 0) or (step == end_step):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = pipeline_utils.compute_global_mean_metrics(
                    train_metrics)

                tock = time.time()
                steps_per_sec = self.eval_frequency / (tock - tick)
                tick = tock

                # Log train summary:
                if master:
                    self.write_train_summary(step=step,
                                             metric_dict=train_metrics,
                                             summary=train_summary,
                                             steps_per_sec=steps_per_sec)

                # Reset metric accumulation for next evaluation cycle.
                train_metrics = []

                # Sync model state across replicas.
                self.train_state = pipeline_utils.sync_model_state_across_replicas(
                    self.train_state)

                # Evaluate and log the results.
                eval_summary, self.train_state = self.eval(
                    step, self.train_state, eval_env_ids)

            # Sync and save.
            self.checkpoint(self.train_state, step)

        return eval_summary, train_summary
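
The hparams fields consulted in these loops (checkpoint, keep_ckpts, ckpnt_steps) behave like an ml_collections.ConfigDict, supporting both attribute access and .get. An illustrative, non-authoritative example of the fields the code above reads:

import ml_collections

def get_config():
    config = ml_collections.ConfigDict()
    config.checkpoint = True             # enable checkpointing at all
    config.keep_ckpts = 3                # forwarded to save_checkpoint(keep=...)
    config.ckpnt_steps = [1000, 5000]    # optional explicit checkpoint steps
    return config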