def aggregate_results(cls, results):
    """
    Aggregate multiple processes' "run_epoch" results into a single result.

    :param results: A list of return values from run_epoch from different
                    processes.
    :type results: list

    :return: A single result dict with results aggregated.
    :rtype: dict
    """
    # Merge the plain validation metrics first; the merged dict also carries
    # the first process's "extra_val_results", which we use as the reference
    # list of timesteps below.
    ret = cls.aggregate_validation_results(results)

    # For each extra-validation timestep, gather that timestep's result from
    # every process and aggregate them. Assumes all processes report the same
    # timesteps in the same order.
    extra_val_aggregated = []
    for i, (timestep, _) in enumerate(ret["extra_val_results"]):
        val_results = [
            process_result["extra_val_results"][i][1]
            for process_result in results
        ]
        extra_val_aggregated.append(
            (timestep, aggregate_eval_results(val_results))
        )

    ret["extra_val_results"] = extra_val_aggregated
    return ret
def aggregate_results(cls, results):
    """
    Aggregate multiple processes' "run_epoch" results into a single result,
    including the per-timestep "extra_val_results" entries.

    :param results: A list of return values from run_epoch from different
                    processes.
    :type results: list

    :return: A single result dict with results aggregated.
    :rtype: dict
    """
    # Let the parent class merge the standard metrics; its return value
    # carries the first process's "extra_val_results" as the timestep
    # reference.
    ret = super().aggregate_results(results)

    # Aggregate each timestep's extra validation results across processes.
    # Assumes all processes report the same timesteps in the same order.
    extra_val_aggregated = []
    for i, (timestep, _) in enumerate(ret["extra_val_results"]):
        val_results = [
            process_result["extra_val_results"][i][1]
            for process_result in results
        ]
        extra_val_aggregated.append(
            (timestep, aggregate_eval_results(val_results))
        )

    ret["extra_val_results"] = extra_val_aggregated
    return ret
def _train(self):
    """
    Run one training epoch across all remote worker processes and return
    the aggregated result dict.

    On the first call after a checkpoint restore, short-circuits with a
    RESULT_DUPLICATE/DONE result if the restored state already satisfies
    the stop criteria. If any remote worker fails, raises RuntimeError;
    any exception kills all workers before propagating.
    """
    self.logger.debug(
        f"_train: {self._trial_info.trial_name}({self.iteration})")
    try:
        # Check if restore checkpoint file fulfills the stop criteria on first run
        if self._first_run:
            self._first_run = False
            if self._restored and self._should_stop():
                self.logger.warning(
                    f"Restored checkpoint file '{self._checkpoint_file}' fulfills "
                    f"stop criteria without additional training.")
                return {
                    # do not train or log results, just stop
                    RESULT_DUPLICATE: True,
                    DONE: True
                }

        # Launch run_epoch on every worker asynchronously; `status` holds
        # the ray object refs for the in-flight epoch runs.
        status = []
        for w in self.procs:
            status.append(w.run_epoch.remote())

        # Wait for remote functions and check for errors
        # Aggregate the results from all processes
        if ray_utils.check_for_failure(status):
            results = ray.get(status)
            # Use a deep copy of the first worker's result as the base so
            # the merged metrics don't alias the worker's dict.
            ret = copy.deepcopy(results[0])
            ret.update(aggregate_eval_results(results))
            self._process_result(ret)

            # Check if we should stop the experiment
            ret[DONE] = self._should_stop()
            return ret

        # At least one worker failed: log and raise so the trial errors out.
        err_msg = (f"{self._trial_info.trial_name}({self.iteration}): "
                   f"One of the remote workers failed during training")
        self.logger.error(err_msg)
        raise RuntimeError(err_msg)
    except Exception:
        # Tear down all workers on any failure, then re-raise the original
        # exception unchanged.
        self._kill_workers()
        raise
def aggregate_validation_results(cls, results):
    """
    Aggregate multiple processes' "validate" results into a single result.

    This method exists separately from "aggregate_results" to support
    running validation outside of "run_epoch" and aggregating those results
    without causing error. Subclasses / mixins implementing
    "aggregate_results" may expect all results to have the extra data
    appended during run_epoch.

    :param results: A list of return values from validate from different
                    processes.
    :type results: list

    :return: A single result dict with results aggregated.
    :rtype: dict
    """
    # Start from a deep copy of the first process's result so the caller's
    # input is never aliased, then overlay the aggregated metrics.
    merged = copy.deepcopy(results[0])
    for key, value in aggregate_eval_results(results).items():
        merged[key] = value
    return merged
def _aggregate_validation_results(cls, results):
    """
    Aggregate multiple processes' "validate" results into a single result.

    :param results: A list of return values from validate from different
                    processes.
    :type results: list

    :return: A single result dict with results aggregated.
    :rtype: dict
    """
    # Deep-copy the base result (was copy.copy) so nested dicts/lists in
    # results[0] are not aliased by the returned dict — consistent with
    # aggregate_validation_results, and mutating the return value can no
    # longer corrupt the caller's input.
    result = copy.deepcopy(results[0])
    result.update(aggregate_eval_results(results))
    return result