Пример #1
0
    def _metric_callback(self, resp: dict, msg: dict,
                         exp_driver: Driver) -> None:
        """Metric message callback.

        Confirms heartbeat messages from the clients and adds logs to the driver.
        """
        exp_driver.add_message(msg)
        resp["type"] = "OK"
Пример #2
0
    def _final_callback(self, resp: dict, msg: dict,
                        exp_driver: Driver) -> None:
        """Final message callback.

        Adds final results to the message queue.
        """
        resp["type"] = "OK"
        exp_driver.add_message(msg)
Пример #3
0
    def _register_callback(self, resp: dict, msg: dict,
                           exp_driver: Driver) -> None:
        """Register message callback.

        Saves workers connection metadata for initialization of distributed
        backend.
        """
        self.reservations.add(msg["data"])
        exp_driver.add_message(msg)
        resp["type"] = "OK"
Пример #4
0
    def _final_callback(self, resp: dict, msg: dict,
                        exp_driver: Driver) -> None:
        """Final message callback.

        Resets the reservation to avoid sending the trial again.
        """
        self.reservations.assign_trial(msg["partition_id"], None)
        resp["type"] = "OK"
        # add metric msg to the exp driver queue
        exp_driver.add_message(msg)
Пример #5
0
 def _get_callback(self, resp: dict, msg: dict, exp_driver: Driver) -> None:
     # lookup reservation to find assigned trial
     trial_id = self.reservations.get_assigned_trial(msg["partition_id"])
     # trial_id needs to be none because experiment_done can be true but
     # the assigned trial might not be finalized yet
     if exp_driver.experiment_done and trial_id is None:
         resp["type"] = "GSTOP"
     else:
         resp["type"] = "TRIAL"
     resp["trial_id"] = trial_id
     # retrieve trial information
     if trial_id is not None:
         resp["data"] = exp_driver.get_trial(trial_id).params
         exp_driver.get_trial(trial_id).status = Trial.RUNNING
     else:
         resp["data"] = None
Пример #6
0
    def _metric_callback(self, resp: dict, msg: dict,
                         exp_driver: Driver) -> None:
        """Metric message callback.

        Determines if a trial should be stopped or not.
        """
        exp_driver.add_message(msg)
        if msg["trial_id"] is None:
            resp["type"] = "OK"
        elif msg["trial_id"] is not None and msg.get("data", None) is None:
            resp["type"] = "OK"
        else:
            # lookup executor reservation to find assigned trial
            # get early stopping flag, should be False for ablation
            flag = exp_driver.get_trial(msg["trial_id"]).get_early_stop()
            resp["type"] = "STOP" if flag else "OK"
Пример #7
0
    def _log_callback(self, resp: dict, _: Any, exp_driver: Driver) -> None:
        """Log message callback.

        Copies logs from the driver and returns them.
        """
        _, log = exp_driver.get_logs()
        resp["type"] = "OK"
        resp["ex_logs"] = log if log else None
        resp["num_trials"] = 1
        resp["to_date"] = 0
        resp["stopped"] = False
        resp["metric"] = "N/A"
Пример #8
0
    def _log_callback(self, resp: dict, _: Any, exp_driver: Driver) -> None:
        """Log message callback.

        Copies logs from the driver and returns them.
        """
        # get data from experiment driver
        result, log = exp_driver.get_logs()
        resp["type"] = "OK"
        resp["ex_logs"] = log if log else None
        resp["num_trials"] = exp_driver.num_trials
        resp["to_date"] = result["num_trials"]
        resp["stopped"] = result["early_stopped"]
        resp["metric"] = result["best_val"]
Пример #9
0
    def _register_callback(self, resp: dict, msg: dict,
                           exp_driver: Driver) -> None:
        """Register message callback.

        Checks if the executor was registered before and reassignes lost trial,
        otherwise assignes a new trial to the executor.
        """
        lost_trial = self.reservations.get_assigned_trial(msg["partition_id"])
        if lost_trial is not None:
            # the trial or executor must have failed
            exp_driver.get_trial(lost_trial).status = Trial.ERROR
            # add a blacklist message to the worker queue
            fail_msg = {
                "partition_id": msg["partition_id"],
                "type": "BLACK",
                "trial_id": lost_trial,
            }
            self.reservations.add(msg["data"])
            exp_driver.add_message(fail_msg)
        else:
            # else add regular registration msg to queue
            self.reservations.add(msg["data"])
            exp_driver.add_message(msg)
        resp["type"] = "OK"