def _metric_callback(self, resp: dict, msg: dict, exp_driver: Driver) -> None:
    """Metric message callback.

    Hands the incoming message over to the experiment driver and
    acknowledges it with an ``OK`` response.
    """
    # queue the message first, then acknowledge it
    exp_driver.add_message(msg)
    resp.update({"type": "OK"})
def _final_callback(self, resp: dict, msg: dict, exp_driver: Driver) -> None:
    """Final message callback.

    Acknowledges the message with ``OK`` and puts it on the message queue
    of the experiment driver.
    """
    resp.update({"type": "OK"})
    # forward the final message to the driver
    exp_driver.add_message(msg)
def _register_callback(self, resp: dict, msg: dict, exp_driver: Driver) -> None:
    """Register message callback.

    Stores the ``data`` payload of the registration message in the
    reservations, forwards the message to the experiment driver and
    acknowledges it with ``OK``.
    """
    reservation = msg["data"]
    self.reservations.add(reservation)

    exp_driver.add_message(msg)
    resp.update({"type": "OK"})
def _final_callback(self, resp: dict, msg: dict, exp_driver: Driver) -> None:
    """Final message callback.

    Clears the trial assigned to the sending partition so the same trial is
    not handed out again, acknowledges with ``OK`` and forwards the message
    to the experiment driver.
    """
    partition_id = msg["partition_id"]
    # reset the reservation: no trial is assigned to this partition anymore
    self.reservations.assign_trial(partition_id, None)

    resp.update({"type": "OK"})

    # put the final message on the experiment driver queue
    exp_driver.add_message(msg)
def _get_callback(self, resp: dict, msg: dict, exp_driver: Driver) -> None:
    """Get message callback.

    Answers with the trial currently reserved for the requesting partition,
    or with a global stop (``GSTOP``) once the experiment is done and no
    trial is assigned to the partition anymore.
    """
    # the reservation tells which trial belongs to this partition
    trial_id = self.reservations.get_assigned_trial(msg["partition_id"])

    # only stop when no trial is assigned: experiment_done may already be
    # true while the assigned trial has not been finalized yet
    stop_worker = exp_driver.experiment_done and trial_id is None
    resp["type"] = "GSTOP" if stop_worker else "TRIAL"
    resp["trial_id"] = trial_id

    if trial_id is None:
        # nothing assigned, so there are no parameters to send
        resp["data"] = None
        return

    # hand out the trial parameters and flag the trial as running
    resp["data"] = exp_driver.get_trial(trial_id).params
    exp_driver.get_trial(trial_id).status = Trial.RUNNING
def _metric_callback(self, resp: dict, msg: dict, exp_driver: Driver) -> None:
    """Metric message callback.

    Forwards the message to the experiment driver and decides whether the
    reporting trial should be stopped: the response type is ``STOP`` when
    the trial's early stop flag is set and ``OK`` otherwise.
    """
    exp_driver.add_message(msg)

    trial_id = msg["trial_id"]
    # messages without a trial or without data are simply acknowledged
    if trial_id is None or msg.get("data") is None:
        resp["type"] = "OK"
        return

    # get early stopping flag, should be False for ablation
    should_stop = exp_driver.get_trial(trial_id).get_early_stop()
    if should_stop:
        resp["type"] = "STOP"
    else:
        resp["type"] = "OK"
def _log_callback(self, resp: dict, _: Any, exp_driver: Driver) -> None:
    """Log message callback.

    Fetches the logs from the experiment driver and returns them in the
    response together with fixed placeholder values for the progress
    fields.
    """
    # only the log part of the driver's answer is used here
    _unused, driver_log = exp_driver.get_logs()

    resp.update(
        {
            "type": "OK",
            # empty logs are reported as None
            "ex_logs": driver_log or None,
            "num_trials": 1,
            "to_date": 0,
            "stopped": False,
            "metric": "N/A",
        }
    )
def _log_callback(self, resp: dict, _: Any, exp_driver: Driver) -> None:
    """Log message callback.

    Fetches the current results and logs from the experiment driver and
    copies them into the response.
    """
    # the driver returns the result summary together with the logs
    summary, driver_log = exp_driver.get_logs()

    resp.update(
        {
            "type": "OK",
            # empty logs are reported as None
            "ex_logs": driver_log or None,
            "num_trials": exp_driver.num_trials,
            "to_date": summary["num_trials"],
            "stopped": summary["early_stopped"],
            "metric": summary["best_val"],
        }
    )
def _register_callback(self, resp: dict, msg: dict, exp_driver: Driver) -> None:
    """Register message callback.

    Checks whether a trial is still assigned to the registering partition.
    If so, that trial is marked as failed and a blacklist message is queued
    instead of the registration message. The reservation is added in both
    cases and the request is acknowledged with ``OK``.
    """
    partition_id = msg["partition_id"]
    orphaned_trial = self.reservations.get_assigned_trial(partition_id)

    if orphaned_trial is None:
        # no trial assigned: queue the regular registration message
        queued_msg = msg
    else:
        # a trial is still assigned, so the trial or executor must have failed
        exp_driver.get_trial(orphaned_trial).status = Trial.ERROR
        # queue a blacklist message in place of the registration message
        queued_msg = {
            "partition_id": partition_id,
            "type": "BLACK",
            "trial_id": orphaned_trial,
        }

    self.reservations.add(msg["data"])
    exp_driver.add_message(queued_msg)
    resp["type"] = "OK"