Example no. 1
    def maybe_eval_and_log(self, eval_summary, master, step, tick,
                           train_metrics, train_summary):
        """Maybe evaluate and log based on the current step value."""
        if (step % self.eval_frequency == 0) or (step == self.total_steps):
            del eval_summary
            del train_summary

            train_metrics = common_utils.get_metrics(train_metrics)
            train_summary = pipeline_utils.compute_global_mean_metrics(
                train_metrics)

            tock = time.time()
            steps_per_sec = self.eval_frequency / (tock - tick)
            tick = tock

            # Log train summary:
            if master:
                self.write_train_summary(step=step,
                                         metric_dict=train_metrics,
                                         summary=train_summary,
                                         steps_per_sec=steps_per_sec)
            # Reset metric accumulation for the next evaluation cycle:
            del train_metrics
            train_metrics = []

            # Sync model state across replicas:
            self.train_state = pipeline_utils.sync_model_state_across_replicas(
                self.train_state)

            # Evaluate and log the results:
            eval_summary, _ = self.eval(step, self.train_state)
        return eval_summary, train_metrics, train_summary, tick
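
    # A hedged usage sketch (not part of the original examples): how this
    # helper might be driven from a training loop. It mirrors the inlined
    # logic of _train_loop in Example no. 5 below; `train_iters`, `start_step`,
    # `end_step`, `master`, and the two-argument pmapped_train_step call are
    # assumptions here.
    def _train_loop_sketch(self, train_iters, start_step, end_step, master):
        train_metrics = []
        train_summary, eval_summary = None, None
        tick = time.time()
        for step in range(start_step + 1, end_step + 1):
            train_batch = self.get_next_batch(train_iters)
            self.train_state, t_metrics = self.pmapped_train_step(
                self.train_state, train_batch)
            train_metrics.append(t_metrics)
            # Periodic logging and evaluation are delegated to the helper
            # above.
            eval_summary, train_metrics, train_summary, tick = (
                self.maybe_eval_and_log(eval_summary, master, step, tick,
                                        train_metrics, train_summary))
        return eval_summary, train_summary
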
    def maybe_eval_and_log(self, eval_env_ids, eval_summary, master, step,
                           tick, train_metrics, train_summary):
        """Maybe evaluate and log based on the current step value.

        This variant additionally takes eval_env_ids, the ids of the
        evaluation environments to evaluate on.
        """
        if (step % self.eval_frequency == 0) or (step == self.total_steps):
            train_metrics = jax.device_get(train_metrics)
            train_metrics = common_utils.stack_forest(train_metrics)
            train_summary = pipeline_utils.compute_global_mean_metrics(
                train_metrics)
            tock = time.time()
            steps_per_sec = self.eval_frequency / (tock - tick)
            tick = tock

            # Log train summary:
            if master:
                self.write_train_summary(step=step,
                                         metric_dict=train_metrics,
                                         summary=train_summary,
                                         steps_per_sec=steps_per_sec)

            # Reset metric accumulation for next evaluation cycle:
            train_metrics = []

            # Sync model state across replicas:
            self.train_state = pipeline_utils.sync_model_state_across_replicas(
                self.train_state)

            # Evaluate and log the results:
            eval_summary, self.train_state = self.eval(step, self.train_state,
                                                       eval_env_ids)
        return eval_summary, train_metrics, train_summary, tick
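
    # A hedged sketch (not part of the original examples): deriving and
    # passing eval_env_ids for the variant above. The derivation from the
    # validation split iterators mirrors _train_loop in Example no. 5 below;
    # the remaining loop variables are assumed to be threaded through exactly
    # as in the sketch after Example no. 1.
    def _maybe_eval_and_log_call_sketch(self, eval_summary, master, step,
                                        tick, train_metrics, train_summary):
        eval_env_ids = list(
            map(int, self.task.dataset.data_iters.validation.keys()))
        return self.maybe_eval_and_log(eval_env_ids, eval_summary, master,
                                       step, tick, train_metrics,
                                       train_summary)
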
    def eval_split(self, train_state, split_name, eval_env_ids=None):
        """Evaluation loop on the specified split.

    Args:
      train_state: TrainState; Object containing training state.
      split_name: str; Name of the data split we want to evaluate the model on.
      eval_env_ids: list(int); Eval environments ids.

    Returns:
      eval_summary, train_state
    """
        data_iters = self.task.dataset.data_iters[split_name]
        if eval_env_ids is None:
            eval_env_ids = list(map(int, data_iters.keys()))

        eval_metrics = {}
        if isinstance(self.steps_per_eval, dict):
            for env_id in eval_env_ids:
                env_id_str = str(env_id)
                env_eval_metrics = []
                for _ in range(self.steps_per_eval[split_name][env_id_str]):
                    env_eval_batches = self.get_next_batch(
                        [data_iters[env_id_str]])
                    e_metrics = self.pmapped_eval_step(train_state,
                                                       env_eval_batches,
                                                       env_id)
                    env_eval_metrics.append(e_metrics)

                env_eval_metrics = common_utils.get_metrics(env_eval_metrics)
                eval_metrics.update(env_eval_metrics)

            eval_summary = pipeline_utils.compute_global_mean_metrics(
                eval_metrics)
        else:
            _, data_iters = list(zip(*dict(data_iters).items()))
            eval_metrics = []
            for _ in range(self.steps_per_eval):
                env_eval_batches = self.get_next_batch(data_iters)
                e_metrics = self.pmapped_eval_step(train_state,
                                                   env_eval_batches, -1)
                eval_metrics.append(e_metrics)

            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_summary = pipeline_utils.compute_global_mean_metrics(
                eval_metrics)

        return eval_summary, eval_metrics
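
The isinstance(self.steps_per_eval, dict) branch above implies that, in the per-environment case, steps_per_eval is a nested mapping keyed first by split name and then by environment id as a string. A hypothetical sketch of that shape; the split names and step counts below are placeholders, not values from the original code:

    # Hypothetical shape of steps_per_eval in the per-environment case
    # (keys inferred from self.steps_per_eval[split_name][env_id_str] above;
    # the numbers are placeholders).
    steps_per_eval = {
        'validation': {'0': 100, '1': 100},
        'test': {'0': 200, '1': 200},
    }
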
Example no. 4
    def eval_split(self, train_state, split_name):
        """Evaluation loop on the specified split.

    Args:
      train_state: TrainState; Object containing training state.
      split_name: str; Name of the data split we want to evaluate the model on.

    Returns:
      eval_summary, train_state
    """
        data_iters = self.task.dataset.data_iters[split_name]
        eval_metrics = []
        for _ in range(self.steps_per_eval):
            env_eval_batches = self.get_next_batch(data_iters)
            e_metrics = self.pmapped_eval_step(train_state, env_eval_batches)
            eval_metrics.append(e_metrics)

        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_summary = pipeline_utils.compute_global_mean_metrics(eval_metrics)

        return eval_summary, eval_metrics
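
A hedged sketch of how eval_split might be wrapped by the eval method that Examples no. 1 and 5 call; write_eval_summary and the hard-coded 'validation' split are assumptions patterned after write_train_summary above, not part of the original code:

    def eval(self, step, train_state):
        """Illustrative wrapper around eval_split; a sketch, not the original."""
        eval_summary, eval_metrics = self.eval_split(train_state, 'validation')
        # write_eval_summary is assumed to mirror write_train_summary above.
        self.write_eval_summary(step=step,
                                metric_dict=eval_metrics,
                                summary=eval_summary)
        return eval_summary, train_state
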
Example no. 5
    def _train_loop(self, environments, start_step, end_step, master):
        """Training loop.

    Trains the model on the given environment set for (end_step - start_step)
    number of steps.

    Args:
      environments: dict; A dictionary from environment name to environment data
        iterator.
      start_step: int; Staring step in the loop.
      end_step: int; End step in the loop.
      master: bool; Is this the host device? If yes, log and checkpoint.

    Returns:
      Evaluation summaries and metrics.
    """
        # Initialize return values.
        train_metrics = []
        train_summary, eval_summary = None, None
        tick = time.time()

        eval_env_ids = list(
            map(int, self.task.dataset.data_iters.validation.keys()))
        train_env_ids, train_iters = list(zip(*dict(environments).items()))
        train_env_ids = list(map(int, train_env_ids))

        for step in range(start_step + 1, end_step + 1):

            # Get next batch.
            train_batch = self.get_next_batch(train_iters)

            # Run train step and get the metrics and the new train state.
            self.train_state, t_metrics = self.pmapped_train_step(
                self.train_state, train_batch, train_env_ids)
            train_metrics.append(t_metrics)

            if (step % self.eval_frequency == 0) or (step == end_step):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = pipeline_utils.compute_global_mean_metrics(
                    train_metrics)

                tock = time.time()
                steps_per_sec = self.eval_frequency / (tock - tick)
                tick = tock

                # Log train summary:
                if master:
                    self.write_train_summary(step=step,
                                             metric_dict=train_metrics,
                                             summary=train_summary,
                                             steps_per_sec=steps_per_sec)

                # Reset metric accumulation for next evaluation cycle.
                train_metrics = []

                # Sync model state across replicas.
                self.train_state = pipeline_utils.sync_model_state_across_replicas(
                    self.train_state)

                # Evaluate and log the results.
                eval_summary, self.train_state = self.eval(
                    step, self.train_state, eval_env_ids)

            # Sync and save.
            self.checkpoint(self.train_state, step)

        return eval_summary, train_summary
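
Finally, a hedged sketch of how _train_loop might be invoked. The trainer object, the 'train' split key, total_steps, and the use of jax.process_index() to identify the lead host are assumptions patterned after the code above, not part of the original examples:

    import jax

    # Illustrative driver: train over every environment in the 'train' split;
    # evaluation, logging and checkpointing happen inside _train_loop itself.
    environments = trainer.task.dataset.data_iters['train']
    eval_summary, train_summary = trainer._train_loop(
        environments,
        start_step=0,
        end_step=trainer.total_steps,
        master=(jax.process_index() == 0))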