예제 #1
0
def populate_experiment(config, app_id, run_id, exp_function):
    """Populates the experiment metadata and attaches it to the experiment.

    Args:
        :config: Experiment config object. ``direction`` and
            ``optimization_key`` are optional attributes; "N/A" is used for
            whichever one is missing.
        :app_id: Application ID
        :run_id: Current experiment run ID
        :exp_function: Name of experiment driver.

    Returns:
        :experiment_json: The value returned by the environment's
            ``attach_experiment_xattr`` for the populated experiment.
    """
    # getattr only falls back to the default on AttributeError, which is
    # exactly the case a config without these fields produces.
    direction = getattr(config, "direction", "N/A")
    opt_key = getattr(config, "optimization_key", "N/A")

    env = EnvSing.get_instance()
    populated = env.populate_experiment(
        config.name,
        exp_function,
        "MAGGY",
        None,
        config.description,
        app_id,
        direction,
        opt_key,
    )
    exp_ml_id = app_id + "_" + str(run_id)
    return EnvSing.get_instance().attach_experiment_xattr(
        exp_ml_id, populated, "INIT")
예제 #2
0
 def _exp_startup_callback(self) -> None:
     """Registers the experiment log directory and the hp config with
     tensorboard upon experiment startup.

     The log directory is looked up from the environment using the driver's
     application and run IDs.
     """
     env = EnvSing.get_instance()
     tensorboard._register(env.get_logdir(self.app_id, self.run_id))
     hparams_logdir = env.get_logdir(self.app_id, self.run_id)
     tensorboard._write_hparams_config(hparams_logdir,
                                       self.config.searchspace)
예제 #3
0
    def json(self) -> str:
        """Serializes the experiment's metadata to a JSON string.

        The "status" entry is "FINISHED" once ``self.experiment_done`` is set
        (in which case the end time, duration, best config and best metric
        are included) and "RUNNING" otherwise.

        :returns: The metadata string.
        """
        timestamp_format = "%Y-%m-%dT%H:%M:%S"

        constants = EnvSing.get_instance().get_constants()
        try:
            # None when the variable is not set in the process environment.
            user = os.environ.get(
                constants.ENV_VARIABLES.HOPSWORKS_USER_ENV_VAR)
        except AttributeError:
            # The constants of this environment do not define the variable.
            user = None

        experiment_json = {
            "project": EnvSing.get_instance().project_name(),
            "user": user,
            "name": self.name,
            "module": "maggy",
            "app_id": str(self.app_id),
            "start": time.strftime(timestamp_format,
                                   time.localtime(self.job_start)),
            "memory_per_executor": str(
                self.spark_context._conf.get("spark.executor.memory")),
            "gpus_per_executor": str(
                self.spark_context._conf.get("spark.executor.gpus")),
            "executors": self.num_executors,
            "logdir": self.log_dir,
            # 'versioned_resources': versioned_resources,
            "description": self.description,
            "experiment_type": self.controller.name(),
            "controller": self.controller.name(),
            "config": json.dumps(self.config_to_dict()),
        }

        if not self.experiment_done:
            experiment_json["status"] = "RUNNING"
        else:
            # "config" already exists, so it keeps its position in the dict
            # and is only overwritten with the best configuration.
            experiment_json.update(
                status="FINISHED",
                finished=time.strftime(timestamp_format,
                                       time.localtime(self.job_end)),
                duration=self.duration,
                config=json.dumps(self.result["best_config"]),
                metric=self.result["best_val"],
            )

        return json.dumps(experiment_json, default=util.json_default_numpy)
예제 #4
0
    def _final_msg_callback(self, msg: dict) -> None:
        """Handles the final message an executor sends for a trial.

        Finalizes the trial, moves it to the final store, updates and logs
        the result, dumps the trial as JSON into its log directory and then
        hands the next trial (if any) to the reporting partition.

        :param msg: The final executor message from the message queue.
        """
        finished = self.get_trial(msg["trial_id"])

        executor_logs = msg.get("logs")
        if executor_logs is not None:
            with self.log_lock:
                self.executor_logs = self.executor_logs + executor_logs

        # Finalize the trial object.
        with finished.lock:
            finished.status = Trial.FINALIZED
            finished.final_metric = msg["data"]
            elapsed = time.time() - finished.start
            finished.duration = util.seconds_to_milliseconds(elapsed)

        # Move the trial from the running store to the finalized ones.
        self._final_store.append(finished)
        self._trial_store.pop(finished.trial_id)

        # Update the result dictionary.
        self._update_result(finished)
        # Keep for later in case tqdm doesn't work.
        self.maggy_log = self._update_maggy_log()
        self.log(self.maggy_log)

        EnvSing.get_instance().dump(
            finished.to_json(),
            "/".join((self.log_dir, finished.trial_id, "trial.json")),
        )

        # Ask the controller what this partition should do next.
        upcoming = self.controller_get_next(finished)

        if upcoming is None:
            # Nothing left to schedule: the experiment is over.
            self.server.reservations.assign_trial(msg["partition_id"], None)
            self.experiment_done = True
            return

        if upcoming == "IDLE":
            # No trial available right now; park the executor.
            self.add_message({
                "type": "IDLE",
                "partition_id": msg["partition_id"],
                "idle_start": time.time(),
            })
            self.server.reservations.assign_trial(msg["partition_id"], None)
            return

        with upcoming.lock:
            upcoming.start = time.time()
            upcoming.status = Trial.SCHEDULED
            self.server.reservations.assign_trial(msg["partition_id"],
                                                  upcoming.trial_id)
            self.add_trial(upcoming)
예제 #5
0
def _exit_handler() -> None:
    """Handles jobs killed by the user.

    If an experiment is running (module-level ``RUNNING`` flag), the state in
    the module-level ``EXPERIMENT_JSON`` is set to "KILLED" and the metadata
    is attached to the experiment as a full update. Any exception raised
    while doing so is logged and swallowed, because this runs at exit time.
    """
    try:
        if RUNNING:
            # The other xattr updates in this code base (_exception_handler,
            # _exp_final_callback) report the outcome under the "state" key;
            # writing "status" here left the experiment state untouched.
            EXPERIMENT_JSON["state"] = "KILLED"
            exp_ml_id = APP_ID + "_" + str(RUN_ID)
            EnvSing.get_instance().attach_experiment_xattr(
                exp_ml_id, EXPERIMENT_JSON, "FULL_UPDATE")
    except Exception as err:
        util.log(err)
예제 #6
0
def _load_hparams(hparams_file):
    """Reads a trial's hparams file through the environment and parses it
    as JSON.

    :param hparams_file: Path of the hparams file.

    :returns: The decoded JSON content.
    """
    return json.loads(EnvSing.get_instance().load(hparams_file))
예제 #7
0
def register_environment(app_id, run_id):
    """Validates IDs and creates an experiment folder in the fs.

    Args:
        :app_id: Application ID. The passed value is not used: it is replaced
            by the application ID of the active Spark context.
        :run_id: Current experiment run ID

    Returns: (app_id, run_id) with the updated IDs.
    """
    spark_app_id = find_spark().sparkContext.applicationId
    app_id, run_id = validate_ml_id(str(spark_app_id), run_id)
    set_ml_id(app_id, run_id)

    # Create the experiment directory and point tensorboard at its logdir.
    EnvSing.get_instance().create_experiment_dir(app_id, run_id)
    experiment_logdir = EnvSing.get_instance().get_logdir(app_id, run_id)
    tensorboard._register(experiment_logdir)
    return app_id, run_id
예제 #8
0
def _exception_handler(duration: int) -> None:
    """Marks the running experiment as FAILED after an exception.

    Does nothing when no experiment is running. Errors raised while updating
    the experiment metadata are logged and not propagated.

    :param duration: Duration of the experiment until exception in milliseconds
    """
    try:
        if not RUNNING:
            return
        EXPERIMENT_JSON["state"] = "FAILED"
        EXPERIMENT_JSON["duration"] = duration
        exp_ml_id = APP_ID + "_" + str(RUN_ID)
        env = EnvSing.get_instance()
        env.attach_experiment_xattr(exp_ml_id, EXPERIMENT_JSON, "FULL_UPDATE")
    except Exception as err:
        util.log(err)
예제 #9
0
    def __init__(
        self, dataset: str, batch_size: int = 1, transform_spec: TransformSpec = None
    ):
        """Creates a sharded reader for the dataset (Petastorm/Parquet).

        The dataset is treated as a Petastorm dataset when a
        ``_common_metadata`` entry exists in its folder, otherwise as plain
        Parquet. The shard index and shard count are read from the ``RANK``
        and ``WORLD_SIZE`` environment variables.

        :param dataset: Path to the dataset.
        :param batch_size: How many samples per batch to load (default: ``1``).
        :param transform_spec: Petastorm transform spec for data augmentation.
            The value is wrapped in ``TransformSpec`` before it is handed to
            the reader factory.
        """
        shard_count = int(os.environ["WORLD_SIZE"])  # Is set at lagom startup.
        shard_index = int(os.environ["RANK"])

        metadata_path = dataset.rstrip("/") + "/_common_metadata"
        if EnvSing.get_instance().exists(metadata_path):
            ds_type = "Petastorm"
        else:
            ds_type = "Parquet"
        print(f"{ds_type} dataset detected in folder {dataset}")

        # Row reader for Petastorm datasets, batch reader for plain Parquet.
        if ds_type == "Petastorm":
            reader_factory = make_reader
        else:
            reader_factory = make_batch_reader

        super().__init__(
            reader_factory(
                dataset,
                cur_shard=shard_index,
                shard_count=shard_count,
                transform_spec=TransformSpec(transform_spec),
            ),
            batch_size=batch_size,
        )
        self.iterator = None
예제 #10
0
    def log(self, log_msg: str) -> None:
        """Writes a timestamped line to the maggy driver log file.

        :param log_msg: The log message; it is converted with ``str``.
        """
        line = "{}: {}\n".format(datetime.now().isoformat(), str(log_msg))
        payload = EnvSing.get_instance().str_or_byte(line)
        self.log_file_handle.write(payload)
예제 #11
0
    def log(self, log_msg, jupyter=False):
        """Logs a message to the executor/trial log files or buffers it for
        Jupyter.

        With ``jupyter=False`` the message is written to the executor log
        file, to the trial log file (when one is open) and passed to
        ``print_executor``. With ``jupyter=True`` the message is written to
        the trial log file (when one is open) and appended to ``self.logs``;
        the executor log file is not written in that mode.

        :param log_msg: Message to log. Non-string values are converted to
            their string representation.
        :type log_msg: object
        :param jupyter: Buffer the message in ``self.logs`` instead of
            writing it to the executor log, defaults to False
        :type jupyter: bool, optional
        """
        with self.lock:
            env = EnvSing.get_instance()
            try:
                msg = (datetime.now().isoformat() +
                       " ({0}/{1}): {2} \n").format(self.partition_id,
                                                    self.task_attempt, log_msg)
                if jupyter:
                    # str() so that non-string messages (dicts, exceptions)
                    # do not raise a TypeError, which the handler below
                    # would not catch.
                    jupyter_log = str(self.partition_id) + ": " + str(log_msg)
                    if self.trial_fd:
                        self.trial_fd.write(env.str_or_byte(msg))
                    self.logs = self.logs + jupyter_log + "\n"
                else:
                    self.fd.write(env.str_or_byte(msg))
                    if self.trial_fd:
                        self.trial_fd.write(env.str_or_byte(msg))
                    self.print_executor(msg)
            # Throws ValueError when operating on closed HDFS file object
            # Throws AttributeError when calling file ops on NoneType object
            except (IOError, ValueError, AttributeError) as e:
                self.fd.write(
                    env.str_or_byte(
                        "An error occurred while writing logs: {}".format(e)))
예제 #12
0
 def init_logger(self, trial_log_file):
     """Remembers the trial log file and opens a file descriptor for it.

     The file is created empty first when it does not exist yet.

     :param trial_log_file: Path of the trial log file.
     """
     self.trial_log_file = trial_log_file
     env = EnvSing.get_instance()
     log_exists = env.exists(self.trial_log_file)
     if not log_exists:
         # Create the file so that it can be opened below.
         env.dump("", self.trial_log_file)
     self.trial_fd = env.open_file(self.trial_log_file, flags="w")
예제 #13
0
def _consume_data(config):
    """Load and return the training and test datasets from config file. If the config.dataset and config.test_set are
    strings they are assumed as path, the functions check if the files or directories exists, if they exists then it
    will run the function in config.process_data, with paramneters config.dataset and config.test_set and return the
    result.
    If the config.dataset and cofig.test_set are not strings but anything else (like a List, nparray, tf.data.Dataset)
    they will returned as they are.
    The types of config.dataset and config.test_set have to be the same.


    :param config: the experiment configuration dictionary

    :returns: dataset

    :raises TypeError: if the config.dataset and config.test_set are of different type
    :raises TypeError: if the process_data function is missing or cannot process the data
    :raises FileNotFoundError: in case config.dataset or config.test_set are not found
    """

    dataset_list = config.dataset
    if not isinstance(dataset_list, list):
        raise TypeError(
            "Dataset must be a list, got {}. If you have only 1 set, provide it within a list".format(
                type(dataset_list)
            )
        )

    data_type = dataset_list[0]

    if data_type == str:
        for ds in dataset_list:
            if type(ds) != data_type:
                raise TypeError(
                    "Dataset contains string and other types, "
                    "if a string is included, it must contain all strings."
                )

        env = EnvSing.get_instance()

        for ds in dataset_list:
            if not (env.isdir(ds) or env.exists(ds)):
                raise FileNotFoundError(f"Path {ds} does not exists.")
        try:
            return config.process_data(dataset_list)
        except TypeError:
            raise TypeError(
                (
                    f"process_data function missing in config, "
                    f"please provide a function that takes 1 argument dataset, "
                    f"reads it and "
                    f"returns the transformed dataset as the list before. "
                    f"config: {config}"
                )
            )
    else:  # type is not str (could be anything)
        return config.process_data(dataset_list)
예제 #14
0
def _setup_logging(reporter: Reporter, log_dir: str) -> Tuple[str, str]:
    """Sets up logging directories and files.

    The reporter's trial ID is set to 0 and its trial logger is initialized
    on ``<log_dir>/training_logs_<partition_id>/output.log``. An existing
    directory is cleaned except for that log file; a missing one is created.

    :param reporter: Reporter responsible for logging.
    :param log_dir: Log directory path on the file system.

    :returns: Tuple containing the path of the tensorboard directory
        and the trial log file.
    """
    # Set once; the original repeated this call a second time with the same
    # value and nothing in between that touched the reporter.
    reporter.set_trial_id(0)
    tb_logdir = log_dir + "/" + "training_logs_" + str(reporter.partition_id)
    trial_log_file = tb_logdir + "/output.log"

    env = EnvSing.get_instance()
    # If trial is repeated, delete trial directory, except log file
    if env.exists(tb_logdir):
        util.clean_dir(tb_logdir, [trial_log_file])
    else:
        env.mkdir(tb_logdir)
    reporter.init_logger(trial_log_file)
    return tb_logdir, trial_log_file
예제 #15
0
    def finalize(self, job_end: float) -> dict:
        """Records the end of the job and persists the experiment summary.

        The summary is printed and logged, the result is dumped to
        ``result.json`` and the experiment metadata to ``maggy.json`` inside
        the log directory.

        :param job_end: Time of the job end.

        :returns: The experiment summary dict.
        """
        self.job_end = job_end
        elapsed_seconds = self.job_end - self.job_start
        self.duration = util.seconds_to_milliseconds(elapsed_seconds)

        summary = self.prep_results(
            util.time_diff(self.job_start, self.job_end))
        print(summary)
        self.log(summary)

        serialized_result = json.dumps(self.result,
                                       default=util.json_default_numpy)
        EnvSing.get_instance().dump(serialized_result,
                                    self.log_dir + "/result.json")
        EnvSing.get_instance().dump(self.json(), self.log_dir + "/maggy.json")
        return self.result_dict
예제 #16
0
def num_executors(sc):
    """
    Get the number of executors as reported by the current environment.

    :param sc: The SparkContext to take the executors from.
    :type sc: SparkContext
    :return: Value returned by the environment's ``get_executors``
    :rtype: int
    """
    env = EnvSing.get_instance()
    return env.get_executors(sc)
예제 #17
0
def clean_dir(clean_dir, keep=None):
    """Deletes all entries in a directory but keeps a few.

    :param clean_dir: Path of the directory to clean.
    :param keep: Paths, as returned by the environment's ``ls``, that must
        not be deleted. Defaults to keeping nothing.

    :raises ValueError: if ``clean_dir`` is not a directory.
    """
    env = EnvSing.get_instance()

    if not env.isdir(clean_dir):
        raise ValueError(
            "{} is not a directory. Use `hops.hdfs.delete()` to delete single "
            "files.".format(clean_dir))

    # None instead of a mutable default list shared between calls.
    keep = () if keep is None else keep
    for path in env.ls(clean_dir):
        if path not in keep:
            env.delete(path, recursive=True)
예제 #18
0
    def _exp_final_callback(self, job_end: float, _: Any) -> dict:
        """Calculates the average test error from all partitions.

        Also marks the experiment as FINISHED (with its duration in
        milliseconds) through the environment and prints a short summary.

        :param job_end: Time of the job end.
        :param _: Catches additional callback arguments.

        :returns: The result in a dictionary.
        """
        # Computed once so the returned and the printed value always agree.
        average = self.average_metric()
        result = {"test result": average}

        # Convert to milliseconds *before* truncating; int(delta) * 1000
        # dropped the sub-second part of the duration.
        duration_ms = int((job_end - self.job_start) * 1000)
        exp_ml_id = str(self.app_id) + "_" + str(self.run_id)
        EnvSing.get_instance().attach_experiment_xattr(
            exp_ml_id,
            {"state": "FINISHED", "duration": duration_ms},
            "FULL_UPDATE",
        )
        print("Final average test loss: {:.3f}".format(average))
        print(
            "Finished experiment. Total run time: "
            + util.time_diff(self.job_start, job_end)
        )
        return result
예제 #19
0
    def __init__(self, config: LagomConfig, app_id: int, run_id: int):
        """Initializes the driver state, shared secret, message queue and
        driver log file.

        The module-level ``DRIVER_SECRET`` is generated on first use and
        reused afterwards.

        :param config: Experiment config.
        :param app_id: Maggy application ID.
        :param run_id: Maggy run ID.
        """
        global DRIVER_SECRET

        # Experiment identity and settings taken from the config.
        self.config = config
        self.app_id = app_id
        self.run_id = run_id
        self.name = config.name
        self.description = config.description
        self.spark_context = None
        self.num_executors = util.num_physical_devices()
        self.server_addr = None
        self.hb_interval = config.hb_interval
        self.job_start = None

        # Secret shared with the executors; created only once per process.
        DRIVER_SECRET = DRIVER_SECRET or self._generate_secret(
            self.SECRET_BYTES)
        self._secret = DRIVER_SECRET

        # Messaging related initialization.
        self._message_q = queue.Queue()
        self.message_callbacks = {}
        self.server = None
        self._register_msg_callbacks()
        self.worker_done = False

        # Logging related initialization.
        self.executor_logs = ""
        self.log_lock = threading.RLock()
        env = EnvSing.get_instance()
        self.log_dir = env.get_logdir(app_id, run_id)
        log_file = self.log_dir + "/maggy.log"
        # Make sure the log file exists before opening a descriptor on it.
        if not env.exists(log_file):
            env.dump("", log_file)
        self.log_file_handle = env.open_file(log_file, flags="w")

        # Result bookkeeping.
        self.exception = None
        self.result = None
        self.result_dict = {}
        self.main_metric_key = None
예제 #20
0
    def start(self, exp_driver):
        """
        Start listener in a background thread.

        The listener accepts client connections and hands every received
        message whose secret matches the driver's secret to
        ``_handle_message``. Connections that send a wrong secret, or whose
        handling raises, are closed.

        Args:
            exp_driver: Experiment driver. Provides the shared secret, the
                ``log`` method and is passed on to the message handler.

        Returns:
            address of the Server as a tuple of (host, port)
        """
        global SERVER_HOST_PORT

        server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_sock, SERVER_HOST_PORT = EnvSing.get_instance().connect_host(
            server_sock, SERVER_HOST_PORT, exp_driver
        )

        def _listen(self, sock, driver):
            """Serve connections until ``self.done`` is set."""
            connections = [sock]

            while not self.done:
                # One second timeout so that ``self.done`` is re-checked.
                readable, _, _ = select.select(connections, [], [], 1)
                for ready in readable:
                    if ready == sock:
                        # New client on the listening socket.
                        client_sock, _ = ready.accept()
                        connections.append(client_sock)
                        continue
                    try:
                        msg = self.receive(ready)
                        if not secrets.compare_digest(
                            msg["secret"], driver._secret
                        ):
                            # Never write the driver secret (or the value a
                            # client sent) to the log: anyone able to read
                            # the log could then impersonate an executor.
                            driver.log(
                                "ERROR: wrong secret received, "
                                "closing connection."
                            )
                            # Raise so that the client socket gets closed.
                            raise PermissionError("wrong secret")

                        self._handle_message(ready, msg, driver)
                    except Exception:
                        ready.close()
                        connections.remove(ready)
            sock.close()

        threading.Thread(
            target=_listen, args=(self, server_sock, exp_driver), daemon=True
        ).start()
        return SERVER_HOST_PORT
예제 #21
0
    def _initialize_logger(self, exp_dir):
        """Opens the optimizer log file inside the experiment directory.

        The file ``optimizer.log`` is created empty when missing, opened for
        writing and a first message is logged.

        :param exp_dir: path of experiment directory
        :type exp_dir: str
        """
        self.log_file = exp_dir + "/optimizer.log"

        env = EnvSing.get_instance()
        log_exists = env.exists(self.log_file)
        if not log_exists:
            env.dump("", self.log_file)
        self.fd = env.open_file(self.log_file, flags="w")

        self._log("Initialized Optimizer Logger")
예제 #22
0
def finalize_experiment(
    experiment_json,
    metric,
    app_id,
    run_id,
    state,
    duration,
    logdir,
    best_logdir,
    optimization_key,
):
    """Forwards all arguments, unchanged and in the same order, to the
    current environment's ``finalize_experiment``.

    The value returned by the environment is not passed on; this function
    returns None.
    """
    env = EnvSing.get_instance()
    forwarded = (
        experiment_json,
        metric,
        app_id,
        run_id,
        state,
        duration,
        logdir,
        best_logdir,
        optimization_key,
    )
    env.finalize_experiment(*forwarded)
예제 #23
0
def lagom(train_fn: Callable, config: LagomConfig) -> dict:
    """Launches a maggy experiment, which depending on 'config' can either
    be a hyperparameter optimization, an ablation study experiment or distributed
    training. Given a search space, objective and a model training procedure `train_fn`
    (black-box function), an experiment is the whole process of finding the
    best hyperparameter combination in the search space, optimizing the
    black-box function. Currently maggy supports random search and a median
    stopping rule.
    **lagom** is a Swedish word meaning "just the right amount".

    :param train_fn: User defined experiment containing the model training.
    :param config: An experiment configuration. For more information, see config.

    :returns: The experiment results as a dict.

    :raises RuntimeError: if another experiment is already running. The
        running experiment is left untouched in that case.
    """
    global APP_ID
    global RUNNING
    global RUN_ID
    job_start = time.time()

    # Reject a concurrent launch *before* entering the try block. Inside it,
    # the rejection would run the exception handler and the cleanup below,
    # marking the experiment that is still running as FAILED and resetting
    # RUNNING / RUN_ID underneath it.
    if RUNNING:
        raise RuntimeError("An experiment is currently running.")
    RUNNING = True
    try:
        spark_context = util.find_spark().sparkContext
        APP_ID = str(spark_context.applicationId)
        APP_ID, RUN_ID = util.register_environment(APP_ID, RUN_ID)
        EnvSing.get_instance().set_app_id(APP_ID)
        driver = lagom_driver(config, APP_ID, RUN_ID)
        return driver.run_experiment(train_fn, config)
    except:  # noqa: E722
        # Deliberately catches everything (including KeyboardInterrupt) so
        # the experiment is marked as failed; the exception is re-raised.
        _exception_handler(
            util.seconds_to_milliseconds(time.time() - job_start))
        raise
    finally:
        # Clean up spark jobs
        RUN_ID += 1
        RUNNING = False
        util.find_spark().sparkContext.setJobGroup("", "")
예제 #24
0
 def __init__(self, server_addr, partition_id, task_attempt, hb_interval, secret):
     """Connects two sockets to the server and stores the client settings.

     :param server_addr: Address of the server, as accepted by
         ``socket.connect``.
     :param partition_id: ID of the partition this client belongs to.
     :param task_attempt: Attempt number of the task.
     :param hb_interval: Heartbeat interval.
     :param secret: Secret shared with the server.
     """
     # Socket used by the main thread.
     self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     self.sock.connect(server_addr)
     # Separate socket used by the heartbeat thread.
     self.hb_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     self.hb_sock.connect(server_addr)

     self.server_addr = server_addr
     self.done = False

     # Own address: IP from the environment, port of the main socket.
     local_ip = EnvSing.get_instance().get_ip_address()
     local_port = self.sock.getsockname()[1]
     self.client_addr = (local_ip, local_port)

     self.partition_id = partition_id
     self.task_attempt = task_attempt
     self.hb_interval = hb_interval
     self._secret = secret
예제 #25
0
def build_summary_json(logdir):
    """Builds the summary json to be read by the experiments service.

    Every directory in ``logdir`` that contains both ``.outputs.json`` and
    ``.hparams.json`` contributes one entry with its parameters and outputs.

    :param logdir: Directory holding one sub-directory per trial.

    :returns: JSON string of the form ``{"combinations": [...]}``.
    """
    env = EnvSing.get_instance()
    combinations = []

    for trial_dir in env.ls(logdir):
        if not env.isdir(trial_dir):
            continue
        outputs_path = trial_dir + "/.outputs.json"
        hparams_path = trial_dir + "/.hparams.json"
        # Skip trials that have not written both files.
        if not (env.exists(outputs_path) and env.exists(hparams_path)):
            continue
        outputs = env.convert_return_file_to_arr(outputs_path)
        parameters = _load_hparams(hparams_path)
        combinations.append({"parameters": parameters, "outputs": outputs})

    summary = {"combinations": combinations}
    return json.dumps(summary, default=json_default_numpy)
예제 #26
0
                def create_tf_dataset(num_epochs, batch_size):
                    """Builds a TFRecord dataset from the training dataset
                    in the feature store.

                    All features of the training dataset schema are used,
                    minus ``ablated_feature`` when one is set in the
                    enclosing scope.

                    :param num_epochs: Passed to ``tf_record_dataset``.
                    :param batch_size: Passed to ``tf_record_dataset``.

                    :returns: The value of ``tf_record_dataset``.
                    """
                    feature_store = (
                        EnvSing.get_instance()
                        .connect_hsfs(engine="training")
                        .get_feature_store()
                    )
                    training_dataset = feature_store.get_training_dataset(
                        training_dataset_name, training_dataset_version
                    )

                    selected = [
                        feature.name for feature in training_dataset.schema
                    ]
                    if ablated_feature is not None:
                        selected.remove(ablated_feature)

                    feeder = training_dataset.tf_data(
                        target_name=label_name, feature_names=selected
                    )
                    return feeder.tf_record_dataset(
                        batch_size=batch_size, num_epochs=num_epochs, process=True
                    )
예제 #27
0
    def __init__(self, log_file, partition_id, task_attempt, print_executor):
        """Initializes the reporter state and opens the executor log file.

        :param log_file: Path of the executor log file; created empty when it
            does not exist yet.
        :param partition_id: ID of the partition the reporter belongs to.
        :param task_attempt: Attempt number of the task.
        :param print_executor: Callable used to print messages on the
            executor.
        """
        # Reporting state.
        self.metric = None
        self.step = -1
        self.lock = threading.RLock()
        self.stop = False
        self.trial_id = None
        self.trial_log_file = None
        self.logs = ""

        # Executor identity.
        self.log_file = log_file
        self.partition_id = partition_id
        self.task_attempt = task_attempt
        self.print_executor = print_executor

        # Open executor log file descriptor
        # This log is for all maggy system related log messages
        env = EnvSing.get_instance()
        log_exists = env.exists(log_file)
        if not log_exists:
            env.dump("", log_file)
        self.fd = env.open_file(log_file, flags="w")
        # The trial log is opened separately once a trial log file is known.
        self.trial_fd = None
예제 #28
0
def handle_return_val(return_val, log_dir, optimization_key, log_file):
    """Handles the return value of the user defined training function.

    The value is uploaded through the environment, validated, and written to
    ``<log_dir>/.outputs.json`` (all outputs plus the log file location) and
    ``<log_dir>/.metric`` (the optimization metric only).

    :param return_val: Value returned by the training function: either the
        metric itself or a dict containing it under ``optimization_key``.
    :param log_dir: Directory the output files are written to.
    :param optimization_key: Name of the metric to optimize.
    :param log_file: Path of the trial log file; stored relative to the
        project path under the "log" key.

    :returns: The value of the optimization metric.

    :raises ValueError: if ``optimization_key`` is empty.
    :raises exceptions.ReturnTypeError: if ``return_val`` is empty or not of
        a supported type.
    :raises KeyError: if a returned dict lacks ``optimization_key``.
    :raises exceptions.MetricTypeError: if the metric is not numeric.
    """
    env = EnvSing.get_instance()

    env.upload_file_output(return_val, log_dir)

    # Return type validation
    if not optimization_key:
        raise ValueError("Optimization key cannot be None.")
    # A metric of exactly 0 is falsy but perfectly valid, so only reject
    # empty values that are not numeric (None, "", {}, ...).
    is_numeric = isinstance(return_val, constants.USER_FCT.NUMERIC_TYPES)
    if not is_numeric and not return_val:
        raise exceptions.ReturnTypeError(optimization_key, return_val)
    if not isinstance(return_val, constants.USER_FCT.RETURN_TYPES):
        raise exceptions.ReturnTypeError(optimization_key, return_val)
    if isinstance(return_val, dict) and optimization_key not in return_val:
        raise KeyError(
            "Returned dictionary does not contain optimization key with the "
            "provided name: {}".format(optimization_key))

    # validate that optimization metric is numeric
    if isinstance(return_val, dict):
        opt_val = return_val[optimization_key]
        # Work on a copy so the "log" entry added below does not leak into
        # the dictionary owned by the caller.
        outputs = dict(return_val)
    else:
        opt_val = return_val
        outputs = {optimization_key: opt_val}

    if not isinstance(opt_val, constants.USER_FCT.NUMERIC_TYPES):
        raise exceptions.MetricTypeError(optimization_key, opt_val)

    outputs["log"] = log_file.replace(env.project_path(), "")

    return_file = log_dir + "/.outputs.json"
    env.dump(json.dumps(outputs, default=json_default_numpy), return_file)

    metric_file = log_dir + "/.metric"
    env.dump(json.dumps(opt_val, default=json_default_numpy), metric_file)

    return opt_val
예제 #29
0
    def wrapper_function(_: Any) -> None:
        """Patched function from dist_executor_fn factory.

        Registers the executor with the server, waits for all reservations,
        configures the PyTorch distributed environment, runs ``train_fn`` and
        reports the resulting metric. While it runs, the builtin ``print`` is
        replaced by a version that also forwards the output to the reporter;
        the original ``print`` is restored on exit.

        :param _: Necessary catch for the iterator given by Spark to the
        function upon foreach calls. Can safely be disregarded.

        :raises RuntimeError: if no master configuration is received.
        """
        EnvSing.get_instance().set_ml_id(app_id, run_id)
        partition_id, _ = util.get_partition_attempt_id()
        client = Client(server_addr, partition_id, 0, hb_interval, secret)
        log_file = log_dir + "/executor_" + str(partition_id) + ".log"

        builtin_print = __builtin__.print
        reporter = Reporter(log_file, partition_id, 0, builtin_print)

        def maggy_print(*args, **kwargs):
            builtin_print(*args, **kwargs)
            reporter.log(" ".join(str(x) for x in args), True)

        __builtin__.print = maggy_print

        try:
            _register_with_servers(client, reporter, partition_id)
            tb_logdir, trial_log_file = _setup_logging(reporter, log_dir)
            reporter.log("Awaiting worker reservations.", True)
            client.await_reservations()
            reporter.log("Reservations complete, configuring PyTorch.", True)
            master_config = client.get_exec_config()[0]
            if not master_config:
                reporter.log("RuntimeError: PyTorch registration failed.",
                             True)
                raise RuntimeError("PyTorch registration failed.")
            addr, port = master_config["host_port"].split(":")
            torch_config = {
                "MASTER_ADDR": addr,
                "MASTER_PORT": port,
                "WORLD_SIZE": str(master_config["num_executors"]),
                "RANK": str(partition_id),
                "LOCAL_RANK": str(0),  # DeepSpeed requires local rank.
                "NCCL_BLOCKING_WAIT": "1",
                "NCCL_DEBUG": "INFO",
            }
            tensorboard._register(tb_logdir)
            reporter.log(f"Torch config is {torch_config}", True)

            _setup_torch_env(torch_config)
            _sanitize_config(config)
            _init_cluster(timeout=60, random_seed=0)
            module = _wrap_module_dispatcher(config)
            _monkey_patch_pytorch(config.zero_lvl)

            reporter.log("Starting distributed training.", True)
            # Only hand the reporter to train functions that ask for it.
            sig = inspect.signature(train_fn)
            if sig.parameters.get("reporter", None):
                retval = train_fn(
                    module=module,
                    hparams=config.hparams,
                    train_set=config.train_set,
                    test_set=config.test_set,
                    reporter=reporter,
                )
            else:
                retval = train_fn(
                    module=module,
                    hparams=config.hparams,
                    train_set=config.train_set,
                    test_set=config.test_set,
                )

            retval = util.handle_return_val(retval, tb_logdir, "Metric",
                                            trial_log_file)
            dist.barrier(
            )  # Don't exit until all executors are done (else NCCL crashes)
            reporter.log("Finished distributed training.", True)
            client.finalize_metric(retval, reporter)
        except:  # noqa: E722
            reporter.log(traceback.format_exc())
            raise
        finally:
            # Undo the monkey patch first: a Python worker reused for a later
            # task would otherwise keep printing through this reporter after
            # its logger has been closed.
            __builtin__.print = builtin_print
            reporter.close_logger()
            client.stop()
            client.close()
예제 #30
0
    def spark_wrapper_function(_: Any) -> None:
        """Patched function from tf_dist_executor_fn factory.

        Registers the executor with the server, waits for all reservations,
        derives this task's role (chief, worker or evaluator) for the
        Tensorflow config, runs ``train_fn`` and reports the resulting metric.
        While it runs, the builtin ``print`` is replaced by a version that
        also forwards the output to the reporter; the original ``print`` is
        restored on exit.

        :param _: Necessary catch for the iterator given by Spark to the
        function upon foreach calls. Can safely be disregarded.
        """
        EnvSing.get_instance().set_ml_id(app_id, run_id)
        partition_id, _ = util.get_partition_attempt_id()
        client = EnvSing.get_instance().get_client(
            server_addr,
            partition_id,
            hb_interval,
            secret,
            socket.socket(socket.AF_INET, socket.SOCK_STREAM),
        )
        log_file = log_dir + "/executor_" + str(partition_id) + ".log"

        reporter = Reporter(log_file, partition_id, 0, __builtin__.print)
        builtin_print = __builtin__.print
        # NOTE(review): _setup_logging is called again inside the try block
        # below; this first call looks redundant - confirm before removing.
        _setup_logging(reporter, log_dir)

        def maggy_print(*args, **kwargs):
            builtin_print(*args, **kwargs)
            reporter.log(" ".join(str(x) for x in args), True)

        __builtin__.print = maggy_print

        try:
            host = EnvSing.get_instance().get_ip_address()

            # Bind to port 0 to learn an OS-assigned port, then advertise the
            # port after it. The probe socket is closed right away instead of
            # being leaked for the lifetime of the task.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tmp_socket:
                tmp_socket.bind(("", 0))
                port = tmp_socket.getsockname()[1] + 1

            host_port = host + ":" + str(port)

            _register_with_servers(client, reporter, partition_id)
            tb_logdir, trial_log_file = _setup_logging(reporter, log_dir)
            tensorboard._register(tb_logdir)

            reporter.log("Awaiting worker reservations.")
            client.await_reservations()

            reservations = client.get_message("RESERVATIONS")
            reporter.log(reservations)
            reporter.log(host_port)
            reporter.log("Reservations complete, configuring Tensorflow.")

            if not reservations:
                reporter.log("Tensorflow registration failed, exiting from all tasks.")
                return

            is_chief = False
            task_index = find_index(host_port, reservations)
            # tf_config aliases the reservations dict: the changes made to
            # ``reservations`` below are part of the Tensorflow config.
            tf_config = reservations
            if task_index == -1:
                tf_config["task"] = {"type": "chief", "index": 0}
                is_chief = True
            else:
                tf_config["task"] = {"type": "worker", "index": task_index}

            # The last worker (if any) is turned into the evaluator.
            last_worker_index = len(reservations["cluster"]["worker"]) - 1
            if not last_worker_index < 0:
                evaluator_node = reservations["cluster"]["worker"][last_worker_index]
                reservations["cluster"]["evaluator"] = [evaluator_node]
                del reservations["cluster"]["worker"][last_worker_index]
                if evaluator_node == host_port:
                    tf_config["task"] = {"type": "evaluator", "index": 0}

            reporter.log(f"Tensorflow config is {tf_config}")

            _setup_tf_config(tf_config)

            strategy = tf.distribute.MultiWorkerMirroredStrategy
            model = _wrap_model(config, strategy, is_chief)

            if config.dataset is not None and config.process_data is not None:
                config.dataset = _consume_data(config)

            reporter.log(f"index of slice {partition_id}")
            reporter.log("Starting distributed training.")

            # Pass only the arguments the train function declares.
            sig = inspect.signature(train_fn)
            kwargs = {}
            if sig.parameters.get("model", None):
                kwargs["model"] = model
            if sig.parameters.get("dataset", None):
                kwargs["dataset"] = config.dataset
            if sig.parameters.get("hparams", None):
                kwargs["hparams"] = config.hparams
            if sig.parameters.get("reporter", None):
                kwargs["reporter"] = reporter
            retval = train_fn(**kwargs)

            # Set retval to work with util.handle_return_value,
            # if there is more than 1 metrics, retval will be a list and
            # retval[0] will contain the final loss
            if isinstance(retval, dict):
                retval = list(retval.values())
            retval = {"Metric": retval[0] if isinstance(retval, list) else retval}
            retval = util.handle_return_val(retval, tb_logdir, "Metric", trial_log_file)
            reporter.log("Finished distributed training.")
            client.finalize_metric(retval, reporter)
        except:  # noqa: E722
            reporter.log(traceback.format_exc())
            raise
        finally:
            # Undo the monkey patch first: a Python worker reused for a later
            # task would otherwise keep printing through this reporter after
            # its logger has been closed.
            __builtin__.print = builtin_print
            reporter.close_logger()
            client.stop()
            client.close()