def maybe_eval_and_log(self, eval_summary, master, step, tick, train_metrics,
                       train_summary):
  """Maybe evaluate and log based on the current step value."""
  if (step % self.eval_frequency == 0) or (step == self.total_steps):
    del eval_summary
    del train_summary

    train_metrics = common_utils.get_metrics(train_metrics)
    train_summary = pipeline_utils.compute_global_mean_metrics(train_metrics)

    tock = time.time()
    steps_per_sec = self.eval_frequency / (tock - tick)
    tick = tock

    # Log train summary.
    if master:
      self.write_train_summary(
          step=step,
          metric_dict=train_metrics,
          summary=train_summary,
          steps_per_sec=steps_per_sec)

    # Reset metric accumulation for the next evaluation cycle.
    del train_metrics
    train_metrics = []

    # Sync model state across replicas.
    self.train_state = pipeline_utils.sync_model_state_across_replicas(
        self.train_state)

    # Evaluate and log the results.
    eval_summary, _ = self.eval(step, self.train_state)

  return eval_summary, train_metrics, train_summary, tick
def maybe_eval_and_log(self, eval_env_ids, eval_summary, master, step, tick,
                       train_metrics, train_summary):
  """Maybe evaluate and log based on the current step value."""
  if (step % self.eval_frequency == 0) or (step == self.total_steps):
    train_metrics = jax.device_get(train_metrics)
    train_metrics = common_utils.stack_forest(train_metrics)
    train_summary = pipeline_utils.compute_global_mean_metrics(train_metrics)

    tock = time.time()
    steps_per_sec = self.eval_frequency / (tock - tick)
    tick = tock

    # Log train summary.
    if master:
      self.write_train_summary(
          step=step,
          metric_dict=train_metrics,
          summary=train_summary,
          steps_per_sec=steps_per_sec)

    # Reset metric accumulation for the next evaluation cycle.
    train_metrics = []

    # Sync model state across replicas.
    self.train_state = pipeline_utils.sync_model_state_across_replicas(
        self.train_state)

    # Evaluate and log the results.
    eval_summary, self.train_state = self.eval(step, self.train_state,
                                               eval_env_ids)

  return eval_summary, train_metrics, train_summary, tick
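# The helper below is a minimal sketch of the metric-aggregation pattern used
# above, assuming each entry of `per_step_metrics` is a dict of per-device
# scalar metrics returned by a pmapped step. The real
# `pipeline_utils.compute_global_mean_metrics` may weight by normalizers, so
# treat this only as an illustration, not the actual implementation.
import jax
from flax.training import common_utils


def global_mean_metrics_sketch(per_step_metrics):
  """Stacks per-step device metrics and reduces each to a single scalar."""
  # Pull metrics to host memory and stack the list of dicts into a dict of
  # arrays shaped [num_steps, num_devices].
  metrics = common_utils.stack_forest(jax.device_get(per_step_metrics))
  # Average over steps and devices to get one scalar per metric name.
  return {name: float(values.mean()) for name, values in metrics.items()}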
def eval_split(self, train_state, split_name, eval_env_ids=None):
  """Evaluation loop on the specified split.

  Args:
    train_state: TrainState; Object containing training state.
    split_name: str; Name of the data split we want to evaluate the model on.
    eval_env_ids: list(int); Eval environment ids.

  Returns:
    eval_summary, eval_metrics
  """
  data_iters = self.task.dataset.data_iters[split_name]
  if eval_env_ids is None:
    eval_env_ids = list(map(int, data_iters.keys()))

  eval_metrics = {}
  if isinstance(self.steps_per_eval, dict):
    # Evaluate each environment separately for its own number of steps.
    for env_id in eval_env_ids:
      env_id_str = str(env_id)
      env_eval_metrics = []
      for _ in range(self.steps_per_eval[split_name][env_id_str]):
        env_eval_batches = self.get_next_batch([data_iters[env_id_str]])
        e_metrics = self.pmapped_eval_step(train_state, env_eval_batches,
                                           env_id)
        env_eval_metrics.append(e_metrics)

      env_eval_metrics = common_utils.get_metrics(env_eval_metrics)
      eval_metrics.update(env_eval_metrics)

    eval_summary = pipeline_utils.compute_global_mean_metrics(eval_metrics)
  else:
    # Evaluate all environments jointly for a fixed number of steps.
    _, data_iters = list(zip(*dict(data_iters).items()))

    eval_metrics = []
    for _ in range(self.steps_per_eval):
      env_eval_batches = self.get_next_batch(data_iters)
      e_metrics = self.pmapped_eval_step(train_state, env_eval_batches, -1)
      eval_metrics.append(e_metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_summary = pipeline_utils.compute_global_mean_metrics(eval_metrics)

  return eval_summary, eval_metrics
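# A minimal sketch of what `get_next_batch` might do in the loops above,
# assuming each iterator yields host-side batches whose leading axis must be
# sharded across local devices before a pmapped step; the real helper may
# prefetch or reshape differently.
from flax.training import common_utils


def get_next_batch_sketch(data_iters):
  """Pulls one batch from each iterator and shards it for pmapped steps."""
  # `common_utils.shard` reshapes each leading axis to
  # [jax.local_device_count(), batch_size // jax.local_device_count(), ...].
  return [common_utils.shard(next(it)) for it in data_iters]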
def eval_split(self, train_state, split_name):
  """Evaluation loop on the specified split.

  Args:
    train_state: TrainState; Object containing training state.
    split_name: str; Name of the data split we want to evaluate the model on.

  Returns:
    eval_summary, eval_metrics
  """
  data_iters = self.task.dataset.data_iters[split_name]

  eval_metrics = []
  for _ in range(self.steps_per_eval):
    env_eval_batches = self.get_next_batch(data_iters)
    e_metrics = self.pmapped_eval_step(train_state, env_eval_batches)
    eval_metrics.append(e_metrics)

  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_summary = pipeline_utils.compute_global_mean_metrics(eval_metrics)

  return eval_summary, eval_metrics
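# A minimal sketch of a pmapped eval step, assuming a Flax-style model `apply`
# function and device-sharded batches with 'inputs' and 'label' fields. The
# real `pmapped_eval_step` in this pipeline takes the full train state (and,
# in the multi-environment variant, an env id) and may compute different
# metrics; `model_apply` and the metric names here are hypothetical.
import jax
import jax.numpy as jnp


def eval_step_sketch(model_apply, params, batch):
  """Computes per-device eval metrics and averages them across devices."""
  logits = model_apply({'params': params}, batch['inputs'])
  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == batch['label'])
  # Cross-device mean so every replica reports the same value.
  return {'accuracy': jax.lax.pmean(accuracy, axis_name='batch')}


pmapped_eval_step_sketch = jax.pmap(
    eval_step_sketch, axis_name='batch', static_broadcasted_argnums=(0,))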
def _train_loop(self, environments, start_step, end_step, master):
  """Training loop.

  Trains the model on the given environment set for (end_step - start_step)
  steps.

  Args:
    environments: dict; A dictionary from environment name to environment
      data iterator.
    start_step: int; Starting step of the loop.
    end_step: int; End step of the loop.
    master: bool; Is this the host device? If yes, log and checkpoint.

  Returns:
    The evaluation and train summaries.
  """
  # Initialize return values.
  train_metrics = []
  train_summary, eval_summary = None, None
  tick = time.time()

  eval_env_ids = list(
      map(int, self.task.dataset.data_iters.validation.keys()))
  train_env_ids, train_iters = list(zip(*dict(environments).items()))
  train_env_ids = list(map(int, train_env_ids))

  for step in range(start_step + 1, end_step + 1):
    # Get the next batch.
    train_batch = self.get_next_batch(train_iters)

    # Run a train step and get the metrics and the new train state.
    self.train_state, t_metrics = self.pmapped_train_step(
        self.train_state, train_batch, train_env_ids)
    train_metrics.append(t_metrics)

    if (step % self.eval_frequency == 0) or (step == end_step):
      train_metrics = common_utils.get_metrics(train_metrics)
      train_summary = pipeline_utils.compute_global_mean_metrics(
          train_metrics)

      tock = time.time()
      steps_per_sec = self.eval_frequency / (tock - tick)
      tick = tock

      # Log train summary.
      if master:
        self.write_train_summary(
            step=step,
            metric_dict=train_metrics,
            summary=train_summary,
            steps_per_sec=steps_per_sec)

      # Reset metric accumulation for the next evaluation cycle.
      train_metrics = []

      # Sync model state across replicas.
      self.train_state = pipeline_utils.sync_model_state_across_replicas(
          self.train_state)

      # Evaluate and log the results.
      eval_summary, self.train_state = self.eval(step, self.train_state,
                                                 eval_env_ids)

    # Sync and save.
    self.checkpoint(self.train_state, step)

  return eval_summary, train_summary
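# A minimal sketch of the cross-replica sync done by
# `pipeline_utils.sync_model_state_across_replicas`, assuming the mutable
# model state (e.g. batch-norm statistics) lives in a replicated `model_state`
# pytree on the train state and that the train state supports `.replace()`;
# both are assumptions about this codebase.
import jax


# Average every leaf of a replicated pytree across devices so all replicas
# hold identical statistics before evaluation and checkpointing.
cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, 'batch'), 'batch')


def sync_model_state_sketch(train_state):
  if getattr(train_state, 'model_state', None) is None:
    return train_state
  return train_state.replace(
      model_state=cross_replica_mean(train_state.model_state))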