Example #1
0
def _setup_logging(reporter: Reporter, log_dir: str) -> Tuple[str, str]:
    """Sets up logging directories and files.

    The tensorboard directory is ``<log_dir>/training_logs_<partition_id>``
    and the log file is ``output.log`` inside it. If the directory already
    exists, its contents are cleaned except for the log file; otherwise the
    directory is created. The reporter's trial id is set to 0 and its logger
    is initialised on the log file.

    :param reporter: Reporter responsible for logging.
    :param log_dir: Log directory path on the file system.

    :returns: Tuple containing the path of the tensorboard directory
        and the trial log file.
    """
    reporter.set_trial_id(0)
    tb_logdir = log_dir + "/" + "training_logs_" + str(reporter.partition_id)
    trial_log_file = tb_logdir + "/output.log"

    env = EnvSing.get_instance()
    # If trial is repeated, delete trial directory, except log file
    if env.exists(tb_logdir):
        util.clean_dir(tb_logdir, [trial_log_file])
    else:
        env.mkdir(tb_logdir)

    reporter.init_logger(trial_log_file)
    return tb_logdir, trial_log_file
Example #2
0
    def _wrapper_fun(_: Any) -> None:
        """Patched function from trial_executor_fn factory.

        Registers with the experiment driver through ``rpc.Client`` and starts
        its heartbeat. It then repeatedly fetches a trial suggestion, prepares
        the trial's log directory, runs ``train_fn`` with the suggested
        parameters and sends the final metric back. This continues until
        ``client.done`` is set.

        While this function runs, the builtin ``print`` is replaced so that
        printed output is also written to the reporter log. The original
        ``print`` is restored before returning, on both the success and the
        error path.

        :param _: Necessary catch for the iterator given by Spark to the
            function upon foreach calls. Can safely be disregarded.
        """
        env = EnvSing.get_instance()

        env.set_ml_id(app_id, run_id)

        # get task context information to determine executor identifier
        partition_id, task_attempt = util.get_partition_attempt_id()

        client = rpc.Client(server_addr, partition_id, task_attempt,
                            hb_interval, secret)
        log_file = (log_dir + "/executor_" + str(partition_id) + "_" +
                    str(task_attempt) + ".log")

        # save the builtin print so it can be restored in ``finally``
        original_print = __builtin__.print

        reporter = Reporter(log_file, partition_id, task_attempt,
                            original_print)

        def maggy_print(*args, **kwargs):
            """Maggy custom print() function."""
            original_print(*args, **kwargs)
            reporter.log(" ".join(str(x) for x in args), True)

        # override the builtin print
        __builtin__.print = maggy_print

        try:
            client_addr = client.client_addr
            host_port = client_addr[0] + ":" + str(client_addr[1])

            exec_spec = {
                "partition_id": partition_id,
                "task_attempt": task_attempt,
                "host_port": host_port,
                "trial_id": None,
            }

            reporter.log("Registering with experiment driver", False)
            client.register(exec_spec)

            client.start_heartbeat(reporter)

            # blocking
            trial_id, parameters = client.get_suggestion(reporter)

            # train_fn is the same for every trial, so inspect its signature
            # once instead of on every loop iteration.
            pass_reporter = "reporter" in inspect.signature(train_fn).parameters

            while not client.done:
                if experiment_type == "ablation":
                    # Remove the ablation keys so they are not passed to
                    # train_fn. The "None" default on pop matches the
                    # tolerant read, so a missing key no longer raises
                    # KeyError.
                    hparams = {
                        "ablated_feature":
                        parameters.pop("ablated_feature", "None"),
                        "ablated_layer":
                        parameters.pop("ablated_layer", "None"),
                    }
                else:
                    hparams = parameters

                tb_logdir = log_dir + "/" + trial_id
                trial_log_file = tb_logdir + "/output.log"
                reporter.set_trial_id(trial_id)

                # If trial is repeated, delete trial directory, except log file
                if env.exists(tb_logdir):
                    util.clean_dir(tb_logdir, [trial_log_file])
                else:
                    env.mkdir(tb_logdir)

                reporter.init_logger(trial_log_file)
                tensorboard._register(tb_logdir)
                env.dump(
                    json.dumps(hparams, default=util.json_default_numpy),
                    tb_logdir + "/.hparams.json",
                )

                try:
                    reporter.log("Starting Trial: {}".format(trial_id), False)
                    reporter.log("Trial Configuration: {}".format(parameters),
                                 False)

                    if experiment_type == "optimization":
                        tensorboard._write_hparams(parameters, trial_id)

                    if pass_reporter:
                        retval = train_fn(**parameters, reporter=reporter)
                    else:
                        retval = train_fn(**parameters)

                    retval = util.handle_return_val(retval, tb_logdir,
                                                    optimization_key,
                                                    trial_log_file)

                except exceptions.EarlyStopException as e:
                    retval = e.metric
                    reporter.log("Early Stopped Trial.", False)

                reporter.log("Finished Trial: {}".format(trial_id), False)
                reporter.log("Final Metric: {}".format(retval), False)
                client.finalize_metric(retval, reporter)

                # blocking
                trial_id, parameters = client.get_suggestion(reporter)

        except BaseException:
            reporter.log(traceback.format_exc(), False)
            raise
        finally:
            # Restore print before closing the logger. maggy_print writes to
            # the reporter, and without this the patched builtin would
            # outlive this call.
            __builtin__.print = original_print
            reporter.close_logger()
            client.stop()
            client.close()