def _setup_logging(reporter: Reporter, log_dir: str) -> Tuple[str, str]:
    """Sets up logging directories and files.

    The directory used is ``<log_dir>/training_logs_<partition_id>`` and the
    log file is ``output.log`` inside that directory. If the directory
    already exists it is cleaned via ``util.clean_dir`` (the log file path is
    passed as the exception list); otherwise it is created. Finally the
    reporter's logger is initialised on the log file.

    :param reporter: Reporter responsible for logging.
    :param log_dir: Log directory path on the file system.

    :returns: Tuple containing the path of the tensorboard directory and the
        trial log file.
    """
    # This setup always uses the fixed trial id 0.
    reporter.set_trial_id(0)

    tb_logdir = log_dir + "/" + "training_logs_" + str(reporter.partition_id)
    trial_log_file = tb_logdir + "/output.log"

    # Look the environment singleton up once instead of once per branch.
    env = EnvSing.get_instance()

    # If trial is repeated, delete trial directory, except log file
    if env.exists(tb_logdir):
        util.clean_dir(tb_logdir, [trial_log_file])
    else:
        env.mkdir(tb_logdir)

    reporter.init_logger(trial_log_file)
    return tb_logdir, trial_log_file
def _wrapper_fun(_: Any) -> None:
    """Patched function from trial_executor_fn factory.

    Registers this executor with the experiment driver, starts the
    heartbeat, and then repeatedly fetches a trial suggestion, runs
    ``train_fn`` with it and reports the final metric until ``client.done``
    is set. While the function runs, the builtin ``print`` is replaced by a
    wrapper that also forwards the printed text to the reporter; the
    original ``print`` is restored when the function exits.

    Names such as ``app_id``, ``run_id``, ``server_addr``, ``hb_interval``,
    ``secret``, ``log_dir``, ``experiment_type``, ``train_fn`` and
    ``optimization_key`` are not defined here; they are read from the
    enclosing scope.

    :param _: Necessary catch for the iterator given by Spark to the
        function upon foreach calls. Can safely be disregarded.

    :raises BaseException: Anything raised while registering or running
        trials (other than ``EarlyStopException`` raised by a trial) is
        logged through the reporter with its traceback and re-raised.
    """
    env = EnvSing.get_instance()
    env.set_ml_id(app_id, run_id)

    # get task context information to determine executor identifier
    partition_id, task_attempt = util.get_partition_attempt_id()

    client = rpc.Client(server_addr, partition_id, task_attempt, hb_interval, secret)
    log_file = (
        log_dir + "/executor_" + str(partition_id) + "_" + str(task_attempt) + ".log"
    )

    # save the builtin print
    original_print = __builtin__.print

    reporter = Reporter(log_file, partition_id, task_attempt, original_print)

    def maggy_print(*args, **kwargs):
        """Maggy custom print() function."""
        original_print(*args, **kwargs)
        reporter.log(" ".join(str(x) for x in args), True)

    # override the builtin print
    __builtin__.print = maggy_print

    try:
        client_addr = client.client_addr
        host_port = client_addr[0] + ":" + str(client_addr[1])

        exec_spec = {
            "partition_id": partition_id,
            "task_attempt": task_attempt,
            "host_port": host_port,
            "trial_id": None,
        }

        reporter.log("Registering with experiment driver", False)
        client.register(exec_spec)
        client.start_heartbeat(reporter)

        # blocking
        trial_id, parameters = client.get_suggestion(reporter)

        # train_fn does not change between trials, so inspect it only once.
        wants_reporter = "reporter" in inspect.signature(train_fn).parameters

        while not client.done:
            if experiment_type == "ablation":
                # Move the ablation keys out of the training parameters.
                # pop() with a default keeps the "None" fallback and avoids
                # a KeyError when a key is absent from the suggestion.
                ablation_params = {
                    "ablated_feature": parameters.pop("ablated_feature", "None"),
                    "ablated_layer": parameters.pop("ablated_layer", "None"),
                }
                hparams = ablation_params
            else:
                hparams = parameters

            tb_logdir = log_dir + "/" + trial_id
            trial_log_file = tb_logdir + "/output.log"
            reporter.set_trial_id(trial_id)

            # If trial is repeated, delete trial directory, except log file
            if env.exists(tb_logdir):
                util.clean_dir(tb_logdir, [trial_log_file])
            else:
                env.mkdir(tb_logdir)

            reporter.init_logger(trial_log_file)
            tensorboard._register(tb_logdir)

            # Ablation trials record only the ablation keys; every other
            # experiment type records the full parameter dict.
            env.dump(
                json.dumps(hparams, default=util.json_default_numpy),
                tb_logdir + "/.hparams.json",
            )

            try:
                reporter.log("Starting Trial: {}".format(trial_id), False)
                reporter.log("Trial Configuration: {}".format(parameters), False)

                if experiment_type == "optimization":
                    tensorboard._write_hparams(parameters, trial_id)

                # Hand the reporter to train_fn only if it declares a
                # parameter named "reporter".
                if wants_reporter:
                    retval = train_fn(**parameters, reporter=reporter)
                else:
                    retval = train_fn(**parameters)

                retval = util.handle_return_val(
                    retval, tb_logdir, optimization_key, trial_log_file
                )
            except exceptions.EarlyStopException as e:
                # An early-stopped trial reports the metric carried by the
                # exception instead of a return value.
                retval = e.metric
                reporter.log("Early Stopped Trial.", False)

            reporter.log("Finished Trial: {}".format(trial_id), False)
            reporter.log("Final Metric: {}".format(retval), False)
            client.finalize_metric(retval, reporter)

            # blocking
            trial_id, parameters = client.get_suggestion(reporter)

    except:  # noqa: E722
        # Record the failure in the executor log, then propagate it.
        reporter.log(traceback.format_exc(), False)
        raise
    finally:
        # Undo the print override so it does not outlive this call and keep
        # writing to a reporter whose logger is closed below.
        __builtin__.print = original_print
        reporter.close_logger()
        client.stop()
        client.close()