def validate_allocation(self, allocation_dict):
    check(
        len(allocation_dict) <= self.total_atoms,
        f"Cannot track more than {self.total_atoms} atoms.",
    )
    check(
        sum(allocation_dict.values()) <= self.total_atoms,
        f"Cannot allocate more than {self.total_atoms} atoms.",
    )
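# `check(...)` is used as a guard throughout these methods but is not
# defined in this section. A minimal sketch, assuming it is an
# assert-style helper that raises on a falsy condition (the real helper
# may use a different error type or add logging):
def check(condition, message=""):
    """Raise if `condition` is falsy; no-op otherwise."""
    if not condition:
        raise AssertionError(message)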
def try_execute_update(self, trial, trial_runner):
    check(self._resource_allocations.get(trial))
    if self.should_execute_resource_update(trial, trial_runner):
        logger.info("Committing resource update.")
        new_resources = trial._get_trainable_cls().to_resources(
            self._resource_allocations[trial]
        )
        self._commit_resource_update(
            trial, trial_runner.trial_executor, new_resources
        )
def step(self, resources, itr):
    check(not (resources > 0 and self._done), "Should not step on done job!")
    self._progress += int(self._scale_factor(resources))
    self._iteration += 1
    score = self.score()
    self._scores += [(self._progress, score)]
    if resources:
        self._progressed_scores += [(score, itr)]
    return score
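# Illustrative only: a standalone loop mirroring how step() accumulates
# progress under a scaling function. `simulate_progress` and the sqrt
# scale factor are hypothetical stand-ins, not part of this module.
def simulate_progress(iterations, resources, scale_factor=lambda r: r ** 0.5):
    progress = 0
    for _ in range(iterations):
        progress += int(scale_factor(resources))
    return progress


# 10 iterations at 4 atoms with sqrt scaling yields a progress of 20.
assert simulate_progress(10, 4) == 20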
def populate_if_needed(self, trials):
    for trial in trials:
        if (
            trial not in self._resource_allocations
            and trial.status == "RUNNING"
        ):
            self._resource_allocations[trial] = self.get_current_atoms(trial)
            logger.warning(f"Adding {trial} to track.")
            self.reallocation_timer[trial] = self.recharge_period
    for trial in list(self._resource_allocations):
        if trial not in trials:
            self._resource_allocations.pop(trial)
            self.reallocation_timer.pop(trial)
    check(
        len(self._resource_allocations) <= self._total_atoms,
        f"Cannot track more than {self.total_atoms} atoms.",
    )
    self._initialized = True
def on_trial_result(self, trial_runner, trial, result):
    """First check ASHA; nothing should occur because of a plateau."""
    decision = ASHAv2.on_trial_result(self, trial_runner, trial, result)
    primary = decision
    # if self.early_stopper.should_kill(trial, result):
    #     decision = TrialScheduler.STOP
    self._startup_times.add(result["setup_time"])
    if result["atoms"] == 1:
        self._single_atom_iteration_times += [result["time_this_iter_s"]]
    self._longest_duration = max(
        self._longest_duration, result["time_total_s"]
    )
    self.allocator.populate_if_needed(self.get_live_trials(trial_runner))
    # This is commented out because there is no "safe" way to argue
    # for this.
    # if self.time_left < self._deadline / self._reduction_factor:
    #     logger.warning("Since time left is less than reduction factor, "
    #                    "POLICY IS TOP_JOB")
    #     self.allocator.set_policy("TOP_JOB")

    # We reallocate according to the current decision. This should be
    # the last change of decision, and will only happen if the trial
    # continues and the proposed allocation improves its progress.
    decision = self.allocator.on_result(
        trial_runner, trial, decision, execute=False
    )
    if decision == TrialScheduler.CONTINUE:
        # Check if, with the new allocation, there is improvement.
        check(self.allocator.get_proposed_atoms(trial))
        if self._improved_progress_after_reallocation(trial):
            self.allocator.try_execute_update(trial, trial_runner)
    if self._no_speculation and self.allocator._policy == "NONE":
        check(primary == decision)
    return decision
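# Hedged sketch of how the runner is expected to act on the returned
# decision. The TrialScheduler constants come from Ray Tune; the
# dispatch below is illustrative only, since the actual handling lives
# in the trial runner, not in this scheduler.
def apply_decision(decision, trial, trial_executor):
    if decision == TrialScheduler.STOP:
        trial_executor.stop_trial(trial)
    elif decision == TrialScheduler.PAUSE:
        trial_executor.pause_trial(trial)
    # TrialScheduler.CONTINUE requires no action; the trial keeps running.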
def __init__(
    self,
    total_atoms,
    resource_policy="UNIFORM",
    scaling_dict=SCALING_MAP["LINEAR"],
    deadline=np.inf,
    allocation_grid=None,
    use_pausing=True,
    grace_period=1,
    reduction_factor=4,
    max_t=100,
    time_attr="training_iteration",
    metric="episode_reward_mean",
    mode="max",
    _no_speculation=False,
    _ignore_overhead=False,
    _no_job_limit=False,
    _assume_linear=False,
    _fixed_exploration=False,
    _exploration_ratio=1.0,
):
    # Arguments for the ablative study.
    self._no_speculation = _no_speculation  # stored
    self._ignore_overhead = _ignore_overhead  # stored
    self._no_job_limit = _no_job_limit  # stored
    self._assume_linear = _assume_linear
    self._fixed_exploration = _fixed_exploration
    self._exploration_ratio = _exploration_ratio

    FIFOScheduler.__init__(self)
    self.use_pausing = use_pausing
    self._num_paused = 0
    self._num_stopped = 0
    self._reduction_factor = reduction_factor
    self._max_t = max_t
    self._metric = metric
    self._time_attr = time_attr
    if mode == "max":
        self._metric_op = 1.0
    elif mode == "min":
        self._metric_op = -1.0
    if self._no_speculation:
        self._brackets = [
            ASHAv2Bracket(
                min_t=grace_period,
                max_t=self._max_t,
                reduction_factor=self._reduction_factor,
                s=0,
            )
        ]
    else:
        self._brackets = [
            _DeadlineBracket(
                self._reduction_factor,
                max_t=self._max_t,
                min_t=grace_period,
                use_pausing=self.use_pausing,
            )
        ]
    if self._fixed_exploration:
        logger.warning(
            f"FIXED EXPLORATION TIME OF {self._exploration_ratio}"
        )
    self.grace_period = grace_period
    self.start_time = time.time()
    self._deadline = deadline
    self._deadline_time = deadline + time.time()
    self._longest_duration = -1
    check(self._deadline_time > self.start_time)
    self.total_atoms = total_atoms
    self.allocator = DynamicAllocator(
        self.total_atoms,
        policy=resource_policy,
        allocation_grid=allocation_grid,
        recharge_period=5,
        metric=self._metric,
        metric_op=self._metric_op,
    )
    if self._assume_linear:
        logger.warning("ABLATION: ASSUMING LINEAR SCALING.")
        scaling_dict = SCALING_MAP["LINEAR"]
    self.scaling_fn = scaling_function_from_dict(scaling_dict)
    self._startup_times = set()
    #: Time it takes for a single iteration
    self._single_atom_iteration_times = []
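# Usage sketch: the class name below is a placeholder for the scheduler
# whose __init__ is shown above, and the argument values are examples
# only (16 atoms, a one-hour deadline in seconds).
scheduler = DeadlineAwareScheduler(  # hypothetical name
    total_atoms=16,
    resource_policy="UNIFORM",
    scaling_dict=SCALING_MAP["LINEAR"],
    deadline=3600,
    grace_period=1,
    reduction_factor=4,
    max_t=100,
    metric="episode_reward_mean",
    mode="max",
)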
def descending_paused(self):
    for t in self.paused:
        check(t in self.recorded)
    return sorted(
        self.paused, key=lambda t: self.recorded[t], reverse=True
    )
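# Standalone illustration of the ordering descending_paused() produces,
# using plain strings and a dict as stand-ins for trials and their
# recorded scores:
recorded = {"t1": 0.2, "t2": 0.9, "t3": 0.5}
paused = ["t1", "t2", "t3"]
assert sorted(paused, key=lambda t: recorded[t], reverse=True) == ["t2", "t3", "t1"]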
def terminate(self, done_info):
    check("terminate" not in self.info)
    self.running = False
    self._done = True
    self.info["terminate"] = done_info