def populate_experiment(config, app_id, run_id, exp_function):
    """Creates a dictionary with the experiment information.

    Args:
        :config: Experiment config object
        :app_id: Application ID
        :run_id: Current experiment run ID
        :exp_function: Name of experiment driver.

    Returns:
        :experiment_json: Dictionary with config info on the experiment.
    """
    try:
        direction = config.direction
    except AttributeError:
        direction = "N/A"
    try:
        opt_key = config.optimization_key
    except AttributeError:
        opt_key = "N/A"
    experiment_json = EnvSing.get_instance().populate_experiment(
        config.name,
        exp_function,
        "MAGGY",
        None,
        config.description,
        app_id,
        direction,
        opt_key,
    )
    exp_ml_id = app_id + "_" + str(run_id)
    experiment_json = EnvSing.get_instance().attach_experiment_xattr(
        exp_ml_id, experiment_json, "INIT"
    )
    return experiment_json
def _exp_startup_callback(self) -> None:
    """Registers the hp config to tensorboard upon experiment startup."""
    tensorboard._register(
        EnvSing.get_instance().get_logdir(self.app_id, self.run_id)
    )
    tensorboard._write_hparams_config(
        EnvSing.get_instance().get_logdir(self.app_id, self.run_id),
        self.config.searchspace,
    )
def json(self) -> str:
    """Exports the experiment's metadata in JSON format.

    :returns: The metadata string.
    """
    user = None
    constants = EnvSing.get_instance().get_constants()
    try:
        if constants.ENV_VARIABLES.HOPSWORKS_USER_ENV_VAR in os.environ:
            user = os.environ[constants.ENV_VARIABLES.HOPSWORKS_USER_ENV_VAR]
    except AttributeError:
        pass
    experiment_json = {
        "project": EnvSing.get_instance().project_name(),
        "user": user,
        "name": self.name,
        "module": "maggy",
        "app_id": str(self.app_id),
        "start": time.strftime(
            "%Y-%m-%dT%H:%M:%S", time.localtime(self.job_start)
        ),
        "memory_per_executor": str(
            self.spark_context._conf.get("spark.executor.memory")
        ),
        "gpus_per_executor": str(
            self.spark_context._conf.get("spark.executor.gpus")
        ),
        "executors": self.num_executors,
        "logdir": self.log_dir,
        # 'versioned_resources': versioned_resources,
        "description": self.description,
        "experiment_type": self.controller.name(),
    }
    experiment_json["controller"] = self.controller.name()
    experiment_json["config"] = json.dumps(self.config_to_dict())
    if self.experiment_done:
        experiment_json["status"] = "FINISHED"
        experiment_json["finished"] = time.strftime(
            "%Y-%m-%dT%H:%M:%S", time.localtime(self.job_end)
        )
        experiment_json["duration"] = self.duration
        experiment_json["config"] = json.dumps(self.result["best_config"])
        experiment_json["metric"] = self.result["best_val"]
    else:
        experiment_json["status"] = "RUNNING"
    return json.dumps(experiment_json, default=util.json_default_numpy)
def _final_msg_callback(self, msg: dict) -> None:
    """Final message callback. Logs trial results and registers executor as idle.

    :param msg: The final executor message from the message queue.
    """
    trial = self.get_trial(msg["trial_id"])
    logs = msg.get("logs", None)
    if logs is not None:
        with self.log_lock:
            self.executor_logs = self.executor_logs + logs
    # finalize the trial object
    with trial.lock:
        trial.status = Trial.FINALIZED
        trial.final_metric = msg["data"]
        trial.duration = util.seconds_to_milliseconds(time.time() - trial.start)
    # move trial to the finalized ones
    self._final_store.append(trial)
    self._trial_store.pop(trial.trial_id)
    # update result dictionary
    self._update_result(trial)
    # keep for later in case tqdm doesn't work
    self.maggy_log = self._update_maggy_log()
    self.log(self.maggy_log)
    EnvSing.get_instance().dump(
        trial.to_json(),
        self.log_dir + "/" + trial.trial_id + "/trial.json",
    )
    # assign new trial
    trial = self.controller_get_next(trial)
    if trial is None:
        self.server.reservations.assign_trial(msg["partition_id"], None)
        self.experiment_done = True
    elif trial == "IDLE":
        self.add_message(
            {
                "type": "IDLE",
                "partition_id": msg["partition_id"],
                "idle_start": time.time(),
            }
        )
        self.server.reservations.assign_trial(msg["partition_id"], None)
    else:
        with trial.lock:
            trial.start = time.time()
            trial.status = Trial.SCHEDULED
            self.server.reservations.assign_trial(
                msg["partition_id"], trial.trial_id
            )
            self.add_trial(trial)
def _exit_handler() -> None:
    """Handles jobs killed by the user."""
    try:
        global RUNNING
        global EXPERIMENT_JSON
        if RUNNING:
            EXPERIMENT_JSON["status"] = "KILLED"
            exp_ml_id = APP_ID + "_" + str(RUN_ID)
            EnvSing.get_instance().attach_experiment_xattr(
                exp_ml_id, EXPERIMENT_JSON, "FULL_UPDATE"
            )
    except Exception as err:
        util.log(err)
def _load_hparams(hparams_file):
    """Loads the HParams configuration from a hparams file of a trial."""
    hparams_file_contents = EnvSing.get_instance().load(hparams_file)
    hparams = json.loads(hparams_file_contents)
    return hparams
def register_environment(app_id, run_id):
    """Validates IDs and creates an experiment folder in the fs.

    Args:
        :app_id: Application ID
        :run_id: Current experiment run ID

    Returns:
        (app_id, run_id) with the updated IDs.
    """
    app_id = str(find_spark().sparkContext.applicationId)
    app_id, run_id = validate_ml_id(app_id, run_id)
    set_ml_id(app_id, run_id)
    # Create experiment directory.
    EnvSing.get_instance().create_experiment_dir(app_id, run_id)
    tensorboard._register(EnvSing.get_instance().get_logdir(app_id, run_id))
    return app_id, run_id
def _exception_handler(duration: int) -> None:
    """Handles exceptions during execution of an experiment.

    :param duration: Duration of the experiment until exception in milliseconds.
    """
    try:
        global RUNNING
        global EXPERIMENT_JSON
        if RUNNING:
            EXPERIMENT_JSON["state"] = "FAILED"
            EXPERIMENT_JSON["duration"] = duration
            exp_ml_id = APP_ID + "_" + str(RUN_ID)
            EnvSing.get_instance().attach_experiment_xattr(
                exp_ml_id, EXPERIMENT_JSON, "FULL_UPDATE"
            )
    except Exception as err:
        util.log(err)
def __init__(
    self, dataset: str, batch_size: int = 1, transform_spec: TransformSpec = None
):
    """Initializes a reader depending on the dataset (Petastorm/Parquet).

    :param dataset: Path to the dataset.
    :param batch_size: How many samples per batch to load (default: ``1``).
    :param transform_spec: Petastorm transform spec for data augmentation.
    """
    num_workers = int(os.environ["WORLD_SIZE"])  # Is set at lagom startup.
    rank = int(os.environ["RANK"])
    # Make reader only compatible with petastorm dataset.
    is_peta_ds = EnvSing.get_instance().exists(
        dataset.rstrip("/") + "/_common_metadata"
    )
    ds_type = "Petastorm" if is_peta_ds else "Parquet"
    print(f"{ds_type} dataset detected in folder {dataset}")
    reader_factory = make_reader if is_peta_ds else make_batch_reader
    reader = reader_factory(
        dataset,
        cur_shard=rank,
        shard_count=num_workers,
        transform_spec=TransformSpec(transform_spec),
    )
    super().__init__(reader, batch_size=batch_size)
    self.iterator = None
def log(self, log_msg: str) -> None:
    """Logs a string to the maggy driver log file.

    :param log_msg: The log message.
    """
    msg = datetime.now().isoformat() + ": " + str(log_msg)
    self.log_file_handle.write(EnvSing.get_instance().str_or_byte(msg + "\n"))
def log(self, log_msg, jupyter=False):
    """Logs a message to the executor logfile and executor stderr and
    optionally prints the message in Jupyter.

    :param log_msg: Message to log.
    :type log_msg: str
    :param jupyter: Print in Jupyter Notebook, defaults to False.
    :type jupyter: bool, optional
    """
    with self.lock:
        env = EnvSing.get_instance()
        try:
            msg = (datetime.now().isoformat() + " ({0}/{1}): {2} \n").format(
                self.partition_id, self.task_attempt, log_msg
            )
            if jupyter:
                jupyter_log = str(self.partition_id) + ": " + log_msg
                if self.trial_fd:
                    self.trial_fd.write(env.str_or_byte(msg))
                self.logs = self.logs + jupyter_log + "\n"
            else:
                self.fd.write(env.str_or_byte(msg))
                if self.trial_fd:
                    self.trial_fd.write(env.str_or_byte(msg))
                self.print_executor(msg)
        # Throws ValueError when operating on a closed HDFS file object.
        # Throws AttributeError when calling file ops on a NoneType object.
        except (IOError, ValueError, AttributeError) as e:
            self.fd.write(
                env.str_or_byte(
                    "An error occurred while writing logs: {}".format(e)
                )
            )
def init_logger(self, trial_log_file):
    """Initializes the trial log file."""
    self.trial_log_file = trial_log_file
    env = EnvSing.get_instance()
    # Open trial log file descriptor.
    if not env.exists(self.trial_log_file):
        env.dump("", self.trial_log_file)
    self.trial_fd = env.open_file(self.trial_log_file, flags="w")
def _consume_data(config):
    """Loads and returns the training and test datasets from the config.

    If the entries of config.dataset are strings, they are treated as paths: the
    function checks that the files or directories exist and then runs
    config.process_data on the dataset list, returning the result. If the entries
    are not strings (e.g. a list, an ndarray or a tf.data.Dataset), they are
    passed to config.process_data as they are. All entries of config.dataset must
    be of the same type.

    :param config: The experiment configuration object.

    :returns: The processed dataset.

    :raises TypeError: If the entries of config.dataset are of different types.
    :raises TypeError: If the process_data function is missing or cannot process
        the data.
    :raises FileNotFoundError: If a path in config.dataset does not exist.
    """
    dataset_list = config.dataset
    if not isinstance(dataset_list, list):
        raise TypeError(
            "Dataset must be a list, got {}. If you have only 1 set, "
            "provide it within a list".format(type(dataset_list))
        )
    data_type = type(dataset_list[0])
    if data_type == str:
        for ds in dataset_list:
            if type(ds) != data_type:
                raise TypeError(
                    "Dataset contains string and other types, "
                    "if a string is included, it must contain all strings."
                )
        env = EnvSing.get_instance()
        for ds in dataset_list:
            if not (env.isdir(ds) or env.exists(ds)):
                raise FileNotFoundError(f"Path {ds} does not exist.")
        try:
            return config.process_data(dataset_list)
        except TypeError:
            raise TypeError(
                (
                    f"process_data function missing in config, "
                    f"please provide a function that takes 1 argument dataset, "
                    f"reads it and "
                    f"returns the transformed dataset as a list. "
                    f"config: {config}"
                )
            )
    else:
        # Type is not str (could be anything).
        return config.process_data(dataset_list)
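# Hedged usage sketch for the contract enforced by _consume_data above. The config
# class below is a stand-in defined only for illustration; the real experiment
# config only needs a `dataset` list and a `process_data` callable that accepts
# that list. The paths and the parsing lambda are placeholder assumptions.
class _ExampleConfig:
    def __init__(self):
        # Two path entries of the same type (str), as _consume_data requires.
        self.dataset = ["Resources/train.parquet", "Resources/test.parquet"]
        # Replace with real reading/augmentation of each path.
        self.process_data = lambda dataset_list: list(dataset_list)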
def _setup_logging(reporter: Reporter, log_dir: str) -> Tuple[str, str]:
    """Sets up logging directories and files.

    :param reporter: Reporter responsible for logging.
    :param log_dir: Log directory path on the file system.

    :returns: Tuple containing the path of the tensorboard directory and
        the trial log file.
    """
    reporter.set_trial_id(0)
    tb_logdir = log_dir + "/" + "training_logs_" + str(reporter.partition_id)
    trial_log_file = tb_logdir + "/output.log"
    # If the trial is repeated, delete the trial directory, except the log file.
    if EnvSing.get_instance().exists(tb_logdir):
        util.clean_dir(tb_logdir, [trial_log_file])
    else:
        EnvSing.get_instance().mkdir(tb_logdir)
    reporter.init_logger(trial_log_file)
    return tb_logdir, trial_log_file
def finalize(self, job_end: float) -> dict:
    """Saves a summary of the experiment to a dict and logs it in the DFS.

    :param job_end: Time of the job end.

    :returns: The experiment summary dict.
    """
    self.job_end = job_end
    self.duration = util.seconds_to_milliseconds(self.job_end - self.job_start)
    duration_str = util.time_diff(self.job_start, self.job_end)
    results = self.prep_results(duration_str)
    print(results)
    self.log(results)
    EnvSing.get_instance().dump(
        json.dumps(self.result, default=util.json_default_numpy),
        self.log_dir + "/result.json",
    )
    EnvSing.get_instance().dump(self.json(), self.log_dir + "/maggy.json")
    return self.result_dict
def num_executors(sc):
    """Gets the number of executors configured for Jupyter.

    :param sc: The SparkContext to take the executors from.
    :type sc: SparkContext
    :return: Number of configured executors for Jupyter
    :rtype: int
    """
    return EnvSing.get_instance().get_executors(sc)
def clean_dir(clean_dir, keep=[]):
    """Deletes all files in a directory but keeps a few."""
    env = EnvSing.get_instance()
    if not env.isdir(clean_dir):
        raise ValueError(
            "{} is not a directory. Use `hops.hdfs.delete()` to delete single "
            "files.".format(clean_dir)
        )
    for path in env.ls(clean_dir):
        if path not in keep:
            env.delete(path, recursive=True)
def _exp_final_callback(self, job_end: float, _: Any) -> dict:
    """Calculates the average test error from all partitions.

    :param job_end: Time of the job end.
    :param _: Catches additional callback arguments.

    :returns: The result in a dictionary.
    """
    result = {"test result": self.average_metric()}
    exp_ml_id = str(self.app_id) + "_" + str(self.run_id)
    EnvSing.get_instance().attach_experiment_xattr(
        exp_ml_id,
        {
            "state": "FINISHED",
            "duration": int((job_end - self.job_start) * 1000),
        },
        "FULL_UPDATE",
    )
    print("Final average test loss: {:.3f}".format(self.average_metric()))
    print(
        "Finished experiment. Total run time: "
        + util.time_diff(self.job_start, job_end)
    )
    return result
def __init__(self, config: LagomConfig, app_id: int, run_id: int):
    """Sets up the RPC server, message queue and logs.

    :param config: Experiment config.
    :param app_id: Maggy application ID.
    :param run_id: Maggy run ID.
    """
    global DRIVER_SECRET
    self.config = config
    self.app_id = app_id
    self.run_id = run_id
    self.name = config.name
    self.description = config.description
    self.spark_context = None
    self.num_executors = util.num_physical_devices()
    self.server_addr = None
    self.hb_interval = config.hb_interval
    self.job_start = None
    DRIVER_SECRET = (
        DRIVER_SECRET if DRIVER_SECRET else self._generate_secret(self.SECRET_BYTES)
    )
    self._secret = DRIVER_SECRET
    # Logging related initialization
    self._message_q = queue.Queue()
    self.message_callbacks = {}
    self.server = None
    self._register_msg_callbacks()
    self.worker_done = False
    self.executor_logs = ""
    self.log_lock = threading.RLock()
    self.log_dir = EnvSing.get_instance().get_logdir(app_id, run_id)
    log_file = self.log_dir + "/maggy.log"
    # Open file descriptor for HDFS to log.
    if not EnvSing.get_instance().exists(log_file):
        EnvSing.get_instance().dump("", log_file)
    self.log_file_handle = EnvSing.get_instance().open_file(log_file, flags="w")
    self.exception = None
    self.result = None
    self.result_dict = {}
    self.main_metric_key = None
def start(self, exp_driver):
    """Starts the listener in a background thread.

    Returns:
        Address of the server as a tuple of (host, port).
    """
    global SERVER_HOST_PORT
    server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server_sock, SERVER_HOST_PORT = EnvSing.get_instance().connect_host(
        server_sock, SERVER_HOST_PORT, exp_driver
    )

    def _listen(self, sock, driver):
        CONNECTIONS = []
        CONNECTIONS.append(sock)
        while not self.done:
            read_socks, _, _ = select.select(CONNECTIONS, [], [], 1)
            for sock in read_socks:
                if sock == server_sock:
                    client_sock, client_addr = sock.accept()
                    CONNECTIONS.append(client_sock)
                    _ = client_addr
                else:
                    try:
                        msg = self.receive(sock)
                        # Raise an exception if the secret does not match,
                        # so the client socket gets closed.
                        if not secrets.compare_digest(
                            msg["secret"], exp_driver._secret
                        ):
                            exp_driver.log(
                                "SERVER secret: {}".format(exp_driver._secret)
                            )
                            exp_driver.log(
                                "ERROR: wrong secret {}".format(msg["secret"])
                            )
                            raise Exception
                        self._handle_message(sock, msg, driver)
                    except Exception:
                        sock.close()
                        CONNECTIONS.remove(sock)
        server_sock.close()

    threading.Thread(
        target=_listen, args=(self, server_sock, exp_driver), daemon=True
    ).start()
    return SERVER_HOST_PORT
def _initialize_logger(self, exp_dir):
    """Initializes the logger of the optimizer.

    :param exp_dir: Path of the experiment directory.
    :type exp_dir: str
    """
    env = EnvSing.get_instance()
    # Configure the logger.
    self.log_file = exp_dir + "/optimizer.log"
    if not env.exists(self.log_file):
        env.dump("", self.log_file)
    self.fd = env.open_file(self.log_file, flags="w")
    self._log("Initialized Optimizer Logger")
def finalize_experiment(
    experiment_json,
    metric,
    app_id,
    run_id,
    state,
    duration,
    logdir,
    best_logdir,
    optimization_key,
):
    """Delegates finalization of the experiment metadata to the environment singleton."""
    EnvSing.get_instance().finalize_experiment(
        experiment_json,
        metric,
        app_id,
        run_id,
        state,
        duration,
        logdir,
        best_logdir,
        optimization_key,
    )
def lagom(train_fn: Callable, config: LagomConfig) -> dict:
    """Launches a maggy experiment, which depending on 'config' can either
    be a hyperparameter optimization, an ablation study experiment or
    distributed training. Given a search space, objective and a model
    training procedure `train_fn` (black-box function), an experiment is
    the whole process of finding the best hyperparameter combination in the
    search space, optimizing the black-box function. Currently maggy supports
    random search and a median stopping rule.

    **lagom** is a Swedish word meaning "just the right amount".

    :param train_fn: User defined experiment containing the model training.
    :param config: An experiment configuration. For more information, see config.

    :returns: The experiment results as a dict.
    """
    global APP_ID
    global RUNNING
    global RUN_ID
    job_start = time.time()
    try:
        if RUNNING:
            raise RuntimeError("An experiment is currently running.")
        RUNNING = True
        spark_context = util.find_spark().sparkContext
        APP_ID = str(spark_context.applicationId)
        APP_ID, RUN_ID = util.register_environment(APP_ID, RUN_ID)
        EnvSing.get_instance().set_app_id(APP_ID)
        driver = lagom_driver(config, APP_ID, RUN_ID)
        return driver.run_experiment(train_fn, config)
    except:  # noqa: E722
        _exception_handler(util.seconds_to_milliseconds(time.time() - job_start))
        raise
    finally:
        # Clean up spark jobs
        RUN_ID += 1
        RUNNING = False
        util.find_spark().sparkContext.setJobGroup("", "")
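# Hedged usage sketch for lagom(). The import paths, Searchspace format and
# OptimizationConfig fields follow maggy's documented public API, but exact names
# can differ between versions; the training body and search ranges below are
# purely illustrative assumptions.
from maggy import experiment, Searchspace
from maggy.experiment_config import OptimizationConfig

sp = Searchspace(kernel=("INTEGER", [2, 8]), pool=("INTEGER", [2, 8]))

def train_fn(kernel, pool, reporter):
    # Build and train a model with the sampled hyperparameters here; the return
    # value is the metric the optimizer uses (see handle_return_val below).
    metric = float(kernel + pool)  # placeholder for a real validation metric
    reporter.log("finished trial with metric {}".format(metric))
    return metric

result = experiment.lagom(
    train_fn=train_fn,
    config=OptimizationConfig(
        num_trials=4,
        optimizer="randomsearch",
        searchspace=sp,
        direction="max",
        name="lagom_demo",
    ),
)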
def __init__(self, server_addr, partition_id, task_attempt, hb_interval, secret):
    # Socket for the main thread.
    self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.sock.connect(server_addr)
    # Socket for the heartbeat thread.
    self.hb_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.hb_sock.connect(server_addr)
    self.server_addr = server_addr
    self.done = False
    self.client_addr = (
        EnvSing.get_instance().get_ip_address(),
        self.sock.getsockname()[1],
    )
    self.partition_id = partition_id
    self.task_attempt = task_attempt
    self.hb_interval = hb_interval
    self._secret = secret
def build_summary_json(logdir):
    """Builds the summary json to be read by the experiments service."""
    combinations = []
    env = EnvSing.get_instance()
    for trial in env.ls(logdir):
        if env.isdir(trial):
            return_file = trial + "/.outputs.json"
            hparams_file = trial + "/.hparams.json"
            if env.exists(return_file) and env.exists(hparams_file):
                metric_arr = env.convert_return_file_to_arr(return_file)
                hparams_dict = _load_hparams(hparams_file)
                combinations.append(
                    {"parameters": hparams_dict, "outputs": metric_arr}
                )
    return json.dumps({"combinations": combinations}, default=json_default_numpy)
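# Illustrative shape of the summary produced by build_summary_json above. The
# hyperparameter names and values are made up, and the exact element format of
# "outputs" depends on EnvSing.convert_return_file_to_arr:
#
# {
#     "combinations": [
#         {"parameters": {"kernel": 4, "pool": 2}, "outputs": [0.91]},
#         {"parameters": {"kernel": 8, "pool": 3}, "outputs": [0.88]}
#     ]
# }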
def create_tf_dataset(num_epochs, batch_size):
    """Builds a tf.data dataset from the feature store training dataset,
    excluding the ablated feature.

    Note: training_dataset_name, training_dataset_version, ablated_feature and
    label_name are captured from the enclosing scope.
    """
    conn = EnvSing.get_instance().connect_hsfs(engine="training")
    fs = conn.get_feature_store()
    td = fs.get_training_dataset(training_dataset_name, training_dataset_version)
    feature_names = [f.name for f in td.schema]
    if ablated_feature is not None:
        feature_names.remove(ablated_feature)
    return td.tf_data(
        target_name=label_name, feature_names=feature_names
    ).tf_record_dataset(batch_size=batch_size, num_epochs=num_epochs, process=True)
def __init__(self, log_file, partition_id, task_attempt, print_executor):
    self.metric = None
    self.step = -1
    self.lock = threading.RLock()
    self.stop = False
    self.trial_id = None
    self.trial_log_file = None
    self.logs = ""
    self.log_file = log_file
    self.partition_id = partition_id
    self.task_attempt = task_attempt
    self.print_executor = print_executor
    # Open executor log file descriptor.
    # This log is for all maggy system related log messages.
    env = EnvSing.get_instance()
    if not env.exists(log_file):
        env.dump("", log_file)
    self.fd = env.open_file(log_file, flags="w")
    self.trial_fd = None
def handle_return_val(return_val, log_dir, optimization_key, log_file):
    """Handles the return value of the user defined training function."""
    env = EnvSing.get_instance()
    env.upload_file_output(return_val, log_dir)
    # Return type validation.
    if not optimization_key:
        raise ValueError("Optimization key cannot be None.")
    if not return_val:
        raise exceptions.ReturnTypeError(optimization_key, return_val)
    if not isinstance(return_val, constants.USER_FCT.RETURN_TYPES):
        raise exceptions.ReturnTypeError(optimization_key, return_val)
    if isinstance(return_val, dict) and optimization_key not in return_val:
        raise KeyError(
            "Returned dictionary does not contain optimization key with the "
            "provided name: {}".format(optimization_key)
        )
    # Validate that the optimization metric is numeric.
    if isinstance(return_val, dict):
        opt_val = return_val[optimization_key]
    else:
        opt_val = return_val
        return_val = {optimization_key: opt_val}
    if not isinstance(opt_val, constants.USER_FCT.NUMERIC_TYPES):
        raise exceptions.MetricTypeError(optimization_key, opt_val)
    # for key, value in return_val.items():
    #     return_val[key] = value if isinstance(value, str) else str(value)
    return_val["log"] = log_file.replace(env.project_path(), "")
    return_file = log_dir + "/.outputs.json"
    env.dump(json.dumps(return_val, default=json_default_numpy), return_file)
    metric_file = log_dir + "/.metric"
    env.dump(json.dumps(opt_val, default=json_default_numpy), metric_file)
    return opt_val
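# Hedged examples of user return values accepted by handle_return_val above,
# assuming optimization_key="accuracy"; the function names and numbers are
# placeholders for illustration only.
def _train_fn_returning_scalar():
    # A bare numeric value is wrapped into {"accuracy": 0.93} before being dumped.
    return 0.93

def _train_fn_returning_dict():
    # A dict must contain the optimization key; extra entries are kept and logged.
    return {"accuracy": 0.93, "loss": 0.21}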
def wrapper_function(_: Any) -> None:
    """Patched function from dist_executor_fn factory.

    :param _: Necessary catch for the iterator given by Spark to the
        function upon foreach calls. Can safely be disregarded.
    """
    EnvSing.get_instance().set_ml_id(app_id, run_id)
    partition_id, _ = util.get_partition_attempt_id()
    client = Client(server_addr, partition_id, 0, hb_interval, secret)
    log_file = log_dir + "/executor_" + str(partition_id) + ".log"
    builtin_print = __builtin__.print
    reporter = Reporter(log_file, partition_id, 0, builtin_print)

    def maggy_print(*args, **kwargs):
        builtin_print(*args, **kwargs)
        reporter.log(" ".join(str(x) for x in args), True)

    __builtin__.print = maggy_print
    try:
        _register_with_servers(client, reporter, partition_id)
        tb_logdir, trial_log_file = _setup_logging(reporter, log_dir)
        reporter.log("Awaiting worker reservations.", True)
        client.await_reservations()
        reporter.log("Reservations complete, configuring PyTorch.", True)
        master_config = client.get_exec_config()[0]
        if not master_config:
            reporter.log("RuntimeError: PyTorch registration failed.", True)
            raise RuntimeError("PyTorch registration failed.")
        addr, port = master_config["host_port"].split(":")
        torch_config = {
            "MASTER_ADDR": addr,
            "MASTER_PORT": port,
            "WORLD_SIZE": str(master_config["num_executors"]),
            "RANK": str(partition_id),
            "LOCAL_RANK": str(0),  # DeepSpeed requires local rank.
            "NCCL_BLOCKING_WAIT": "1",
            "NCCL_DEBUG": "INFO",
        }
        tensorboard._register(tb_logdir)
        reporter.log(f"Torch config is {torch_config}", True)
        _setup_torch_env(torch_config)
        _sanitize_config(config)
        _init_cluster(timeout=60, random_seed=0)
        module = _wrap_module_dispatcher(config)
        _monkey_patch_pytorch(config.zero_lvl)
        reporter.log("Starting distributed training.", True)
        sig = inspect.signature(train_fn)
        if sig.parameters.get("reporter", None):
            retval = train_fn(
                module=module,
                hparams=config.hparams,
                train_set=config.train_set,
                test_set=config.test_set,
                reporter=reporter,
            )
        else:
            retval = train_fn(
                module=module,
                hparams=config.hparams,
                train_set=config.train_set,
                test_set=config.test_set,
            )
        retval = util.handle_return_val(retval, tb_logdir, "Metric", trial_log_file)
        # Don't exit until all executors are done (else NCCL crashes).
        dist.barrier()
        reporter.log("Finished distributed training.", True)
        client.finalize_metric(retval, reporter)
    except:  # noqa: E722
        reporter.log(traceback.format_exc())
        raise
    finally:
        reporter.close_logger()
        client.stop()
        client.close()
def spark_wrapper_function(_: Any) -> None:
    """Patched function from tf_dist_executor_fn factory.

    :param _: Necessary catch for the iterator given by Spark to the
        function upon foreach calls. Can safely be disregarded.
    """
    EnvSing.get_instance().set_ml_id(app_id, run_id)
    partition_id, _ = util.get_partition_attempt_id()
    client = EnvSing.get_instance().get_client(
        server_addr,
        partition_id,
        hb_interval,
        secret,
        socket.socket(socket.AF_INET, socket.SOCK_STREAM),
    )
    log_file = log_dir + "/executor_" + str(partition_id) + ".log"
    reporter = Reporter(log_file, partition_id, 0, __builtin__.print)
    builtin_print = __builtin__.print
    _setup_logging(reporter, log_dir)

    def maggy_print(*args, **kwargs):
        builtin_print(*args, **kwargs)
        reporter.log(" ".join(str(x) for x in args), True)

    __builtin__.print = maggy_print
    try:
        host = EnvSing.get_instance().get_ip_address()
        tmp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        tmp_socket.bind(("", 0))
        port = tmp_socket.getsockname()[1] + 1
        host_port = host + ":" + str(port)
        _register_with_servers(client, reporter, partition_id)
        tb_logdir, trial_log_file = _setup_logging(reporter, log_dir)
        tensorboard._register(tb_logdir)
        reporter.log("Awaiting worker reservations.")
        client.await_reservations()
        reservations = client.get_message("RESERVATIONS")
        reporter.log(reservations)
        reporter.log(host_port)
        reporter.log("Reservations complete, configuring Tensorflow.")
        if not reservations:
            reporter.log("Tensorflow registration failed, exiting from all tasks.")
            return
        workers_host_port = []
        for i in list(reservations["cluster"]):
            if len(reservations["cluster"][i]) > 0:
                workers_host_port.append(reservations["cluster"][i][0])
        is_chief = False
        task_index = find_index(host_port, reservations)
        tf_config = reservations
        if task_index == -1:
            tf_config["task"] = {"type": "chief", "index": 0}
            is_chief = True
        else:
            tf_config["task"] = {"type": "worker", "index": task_index}
        last_worker_index = len(reservations["cluster"]["worker"]) - 1
        if not last_worker_index < 0:
            evaluator_node = reservations["cluster"]["worker"][last_worker_index]
            reservations["cluster"]["evaluator"] = [evaluator_node]
            del reservations["cluster"]["worker"][last_worker_index]
            if evaluator_node == host_port:
                tf_config["task"] = {"type": "evaluator", "index": 0}
        reporter.log(f"Tensorflow config is {tf_config}")
        _setup_tf_config(tf_config)
        strategy = tf.distribute.MultiWorkerMirroredStrategy
        model = _wrap_model(config, strategy, is_chief)
        if config.dataset is not None and config.process_data is not None:
            config.dataset = _consume_data(config)
        reporter.log(f"index of slice {partition_id}")
        reporter.log("Starting distributed training.")
        sig = inspect.signature(train_fn)
        kwargs = {}
        if sig.parameters.get("model", None):
            kwargs["model"] = model
        if sig.parameters.get("dataset", None):
            kwargs["dataset"] = config.dataset
        if sig.parameters.get("hparams", None):
            kwargs["hparams"] = config.hparams
        if sig.parameters.get("reporter", None):
            kwargs["reporter"] = reporter
            retval = train_fn(**kwargs)
        else:
            retval = train_fn(**kwargs)
        # Set retval to work with util.handle_return_val: if there is more than
        # one metric, retval will be a list and retval[0] will contain the final loss.
        retval_list = []
        if isinstance(retval, dict):
            for item in retval.items():
                retval_list.append(item[1])
            retval = retval_list
        retval = {"Metric": retval[0] if isinstance(retval, list) else retval}
        retval = util.handle_return_val(retval, tb_logdir, "Metric", trial_log_file)
        reporter.log("Finished distributed training.")
        client.finalize_metric(retval, reporter)
    except:  # noqa: E722
        reporter.log(traceback.format_exc())
        raise
    finally:
        reporter.close_logger()
        client.stop()
        client.close()