def _exp_startup_callback(self) -> None:
    """Registers the hyperparameter config with tensorboard upon experiment
    startup."""
    log_dir = EnvSing.get_instance().get_logdir(self.app_id, self.run_id)
    tensorboard._register(log_dir)
    tensorboard._write_hparams_config(log_dir, self.config.searchspace)
Example #2
def register_environment(app_id, run_id):
    """Validates IDs and creates an experiment folder in the fs.

    Args:
        :app_id: Application ID
        :run_id: Current experiment run ID

    Returns: (app_id, run_id) with the updated IDs.
    """
    app_id = str(find_spark().sparkContext.applicationId)
    app_id, run_id = validate_ml_id(app_id, run_id)
    set_ml_id(app_id, run_id)
    # Create experiment directory.
    EnvSing.get_instance().create_experiment_dir(app_id, run_id)
    tensorboard._register(EnvSing.get_instance().get_logdir(app_id, run_id))
    return app_id, run_id
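
Note that the app_id argument is immediately replaced from the Spark context inside the function, so callers only need a meaningful run_id. A hedged usage sketch (the run_id value is a placeholder):

app_id, run_id = register_environment(app_id=None, run_id=1)
print("experiment dir ready for", app_id, "run", run_id)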
Example #3
    def wrapper_function(_: Any) -> None:
        """Patched function from dist_executor_fn factory.

        :param _: Necessary catch for the iterator given by Spark to the
        function upon foreach calls. Can safely be disregarded.
        """
        EnvSing.get_instance().set_ml_id(app_id, run_id)
        partition_id, _ = util.get_partition_attempt_id()
        client = Client(server_addr, partition_id, 0, hb_interval, secret)
        log_file = log_dir + "/executor_" + str(partition_id) + ".log"

        builtin_print = __builtin__.print
        reporter = Reporter(log_file, partition_id, 0, builtin_print)

        def maggy_print(*args, **kwargs):
            builtin_print(*args, **kwargs)
            reporter.log(" ".join(str(x) for x in args), True)

        __builtin__.print = maggy_print

        try:
            _register_with_servers(client, reporter, partition_id)
            tb_logdir, trial_log_file = _setup_logging(reporter, log_dir)
            reporter.log("Awaiting worker reservations.", True)
            client.await_reservations()
            reporter.log("Reservations complete, configuring PyTorch.", True)
            master_config = client.get_exec_config()[0]
            if not master_config:
                reporter.log("RuntimeError: PyTorch registration failed.",
                             True)
                raise RuntimeError("PyTorch registration failed.")
            addr, port = master_config["host_port"].split(":")
            torch_config = {
                "MASTER_ADDR": addr,
                "MASTER_PORT": port,
                "WORLD_SIZE": str(master_config["num_executors"]),
                "RANK": str(partition_id),
                "LOCAL_RANK": str(0),  # DeepSpeed requires local rank.
                "NCCL_BLOCKING_WAIT": "1",
                "NCCL_DEBUG": "INFO",
            }
            tensorboard._register(tb_logdir)
            reporter.log(f"Torch config is {torch_config}", True)

            _setup_torch_env(torch_config)
            _sanitize_config(config)
            _init_cluster(timeout=60, random_seed=0)
            module = _wrap_module_dispatcher(config)
            _monkey_patch_pytorch(config.zero_lvl)

            reporter.log("Starting distributed training.", True)
            sig = inspect.signature(train_fn)
            if sig.parameters.get("reporter", None):
                retval = train_fn(
                    module=module,
                    hparams=config.hparams,
                    train_set=config.train_set,
                    test_set=config.test_set,
                    reporter=reporter,
                )
            else:
                retval = train_fn(
                    module=module,
                    hparams=config.hparams,
                    train_set=config.train_set,
                    test_set=config.test_set,
                )

            retval = util.handle_return_val(retval, tb_logdir, "Metric",
                                            trial_log_file)
            # Don't exit until all executors are done (else NCCL crashes).
            dist.barrier()
            reporter.log("Finished distributed training.", True)
            client.finalize_metric(retval, reporter)
        except:  # noqa: E722
            reporter.log(traceback.format_exc())
            raise
        finally:
            reporter.close_logger()
            client.stop()
            client.close()
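
For reference, a minimal sketch of a train_fn compatible with the dispatch above. Only the keyword names (module, hparams, train_set, test_set, reporter) come from the calls shown; the module and dataset semantics are assumptions.

import torch

def train_fn(module, hparams, train_set, test_set, reporter):
    # Hypothetical user training function; `module` is assumed to be a
    # ready-to-train torch.nn.Module handed over by the wrapper.
    optimizer = torch.optim.SGD(module.parameters(), lr=hparams.get("lr", 0.01))
    loss_fn = torch.nn.CrossEntropyLoss()
    for x, y in train_set:
        optimizer.zero_grad()
        loss = loss_fn(module(x), y)
        loss.backward()
        optimizer.step()
    reporter.log(f"final training loss: {loss.item()}", True)
    return loss.item()  # becomes the trial metric via util.handle_return_val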
Example #4
    def _wrapper_fun(iter):
        """Wraps the user supplied training function in order to be passed to
        the Spark Executors.

        Args:
            iter: Partition iterator supplied by Spark's foreachPartition
                call; unused.

        Returns:
            None.
        """
        experiment_utils._set_ml_id(app_id, run_id)

        # get task context information to determine executor identifier
        partition_id, task_attempt = util.get_partition_attempt_id()

        client = rpc.Client(server_addr, partition_id, task_attempt,
                            hb_interval, secret)
        log_file = (log_dir + "/executor_" + str(partition_id) + "_" +
                    str(task_attempt) + ".log")

        # save the builtin print
        original_print = __builtin__.print

        reporter = Reporter(log_file, partition_id, task_attempt,
                            original_print)

        def maggy_print(*args, **kwargs):
            """Maggy custom print() function."""
            original_print(*args, **kwargs)
            reporter.log(" ".join(str(x) for x in args), True)

        # override the builtin print
        __builtin__.print = maggy_print

        try:
            client_addr = client.client_addr

            host_port = client_addr[0] + ":" + str(client_addr[1])

            exec_spec = {}
            exec_spec["partition_id"] = partition_id
            exec_spec["task_attempt"] = task_attempt
            exec_spec["host_port"] = host_port
            exec_spec["trial_id"] = None

            reporter.log("Registering with experiment driver", False)
            client.register(exec_spec)

            client.start_heartbeat(reporter)

            # blocking
            trial_id, parameters = client.get_suggestion(reporter)

            while not client.done:
                if experiment_type == "ablation":
                    ablation_params = {
                        "ablated_feature":
                        parameters.get("ablated_feature", "None"),
                        "ablated_layer":
                        parameters.get("ablated_layer", "None"),
                    }
                    parameters.pop("ablated_feature")
                    parameters.pop("ablated_layer")

                tb_logdir = log_dir + "/" + trial_id
                trial_log_file = tb_logdir + "/output.log"
                reporter.set_trial_id(trial_id)

                # If trial is repeated, delete trial directory, except log file
                if hopshdfs.exists(tb_logdir):
                    util._clean_dir(tb_logdir, [trial_log_file])
                else:
                    hopshdfs.mkdir(tb_logdir)

                reporter.init_logger(trial_log_file)
                tensorboard._register(tb_logdir)
                if experiment_type == "ablation":
                    hopshdfs.dump(
                        json.dumps(ablation_params,
                                   default=util.json_default_numpy),
                        tb_logdir + "/.hparams.json",
                    )

                else:
                    hopshdfs.dump(
                        json.dumps(parameters,
                                   default=util.json_default_numpy),
                        tb_logdir + "/.hparams.json",
                    )

                try:
                    reporter.log("Starting Trial: {}".format(trial_id), False)
                    reporter.log("Trial Configuration: {}".format(parameters),
                                 False)

                    if experiment_type == "optimization":
                        tensorboard._write_hparams(parameters, trial_id)

                    sig = inspect.signature(map_fun)
                    if sig.parameters.get("reporter", None):
                        retval = map_fun(**parameters, reporter=reporter)
                    else:
                        retval = map_fun(**parameters)

                    if experiment_type == "optimization":
                        tensorboard._write_session_end()

                    retval = util._handle_return_val(retval, tb_logdir,
                                                     optimization_key,
                                                     trial_log_file)

                except exceptions.EarlyStopException as e:
                    retval = e.metric
                    reporter.log("Early Stopped Trial.", False)

                reporter.log("Finished Trial: {}".format(trial_id), False)
                reporter.log("Final Metric: {}".format(retval), False)
                client.finalize_metric(retval, reporter)

                # blocking
                trial_id, parameters = client.get_suggestion(reporter)

        except:  # noqa: E722
            reporter.log(traceback.format_exc(), False)
            raise
        finally:
            reporter.close_logger()
            client.stop()
            client.close()
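
A hedged sketch of a map_fun this loop can drive: the executor calls map_fun(**parameters), so the argument names must match the searchspace keys (kernel and dropout below are made-up), and the optional reporter is injected only when the signature declares it.

def map_fun(kernel, dropout, reporter):
    # Hypothetical black-box objective; print() is patched above to also
    # log through the reporter.
    print("trial config: kernel={}, dropout={}".format(kernel, dropout))
    accuracy = 1.0 - dropout  # placeholder for real model training
    return accuracy  # consumed by util._handle_return_val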
Example #5
def init_ml_tracking(self, app_id, run_id):
    """Registers the experiment log directory with tensorboard."""
    tensorboard._register(experiment_utils._get_logdir(app_id, run_id))
Example #6
    def spark_wrapper_function(_: Any) -> None:
        """Patched function from tf_dist_executor_fn factory.

        :param _: Necessary catch for the iterator given by Spark to the
        function upon foreach calls. Can safely be disregarded.
        """
        EnvSing.get_instance().set_ml_id(app_id, run_id)
        partition_id, _ = util.get_partition_attempt_id()
        client = EnvSing.get_instance().get_client(
            server_addr,
            partition_id,
            hb_interval,
            secret,
            socket.socket(socket.AF_INET, socket.SOCK_STREAM),
        )
        log_file = log_dir + "/executor_" + str(partition_id) + ".log"

        reporter = Reporter(log_file, partition_id, 0, __builtin__.print)
        builtin_print = __builtin__.print
        _setup_logging(reporter, log_dir)

        def maggy_print(*args, **kwargs):
            builtin_print(*args, **kwargs)
            reporter.log(" ".join(str(x) for x in args), True)

        __builtin__.print = maggy_print

        try:
            host = EnvSing.get_instance().get_ip_address()

            tmp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            tmp_socket.bind(("", 0))
            port = tmp_socket.getsockname()[1] + 1

            host_port = host + ":" + str(port)

            _register_with_servers(client, reporter, partition_id)
            tb_logdir, trial_log_file = _setup_logging(reporter, log_dir)
            tensorboard._register(tb_logdir)

            reporter.log("Awaiting worker reservations.")
            client.await_reservations()

            reservations = client.get_message("RESERVATIONS")
            reporter.log(reservations)
            reporter.log(host_port)
            reporter.log("Reservations complete, configuring Tensorflow.")

            if not reservations:
                reporter.log("Tensorflow registration failed, exiting from all tasks.")
                return

            workers_host_port = []

            for i in list(reservations["cluster"]):
                if len(reservations["cluster"][i]) > 0:
                    workers_host_port.append(reservations["cluster"][i][0])

            is_chief = False
            task_index = find_index(host_port, reservations)
            tf_config = reservations
            if task_index == -1:
                tf_config["task"] = {"type": "chief", "index": 0}
                is_chief = True
            else:
                tf_config["task"] = {"type": "worker", "index": task_index}

            last_worker_index = len(reservations["cluster"]["worker"]) - 1
            if last_worker_index >= 0:
                evaluator_node = reservations["cluster"]["worker"][last_worker_index]
                reservations["cluster"]["evaluator"] = [evaluator_node]
                del reservations["cluster"]["worker"][last_worker_index]
                if evaluator_node == host_port:
                    tf_config["task"] = {"type": "evaluator", "index": 0}

            reporter.log(f"Tensorflow config is {tf_config}")

            _setup_tf_config(tf_config)

            strategy = tf.distribute.MultiWorkerMirroredStrategy
            model = _wrap_model(config, strategy, is_chief)

            if config.dataset is not None and config.process_data is not None:
                config.dataset = _consume_data(config)

            reporter.log(f"index of slice {partition_id}")
            reporter.log("Starting distributed training.")
            sig = inspect.signature(train_fn)

            kwargs = {}
            if sig.parameters.get("model", None):
                kwargs["model"] = model
            if sig.parameters.get("dataset", None):
                kwargs["dataset"] = config.dataset
            if sig.parameters.get("hparams", None):
                kwargs["hparams"] = config.hparams

            if sig.parameters.get("reporter", None):
                kwargs["reporter"] = reporter
                retval = train_fn(**kwargs)
            else:
                retval = train_fn(**kwargs)

            # Make retval compatible with util.handle_return_val: if the
            # function returned more than one metric as a dict, keep the
            # first value as the final metric.
            if isinstance(retval, dict):
                retval = list(retval.values())
            retval = {"Metric": retval[0] if isinstance(retval, list) else retval}
            retval = util.handle_return_val(retval, tb_logdir, "Metric", trial_log_file)
            reporter.log("Finished distributed training.")
            client.finalize_metric(retval, reporter)
        except:  # noqa: E722
            reporter.log(traceback.format_exc())
            raise
        finally:
            reporter.close_logger()
            client.stop()
            client.close()
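
_setup_tf_config is not shown in this snippet; a minimal sketch of what such a helper presumably does, given that tf.distribute.MultiWorkerMirroredStrategy discovers its cluster layout through the TF_CONFIG environment variable:

import json
import os

def _setup_tf_config(tf_config):
    # TensorFlow expects a JSON document with "cluster" and "task" keys,
    # exactly the structure assembled from the reservations above.
    os.environ["TF_CONFIG"] = json.dumps(tf_config)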
Example #7
    def python_wrapper_function(_: Any) -> Any:
        """Patched function from tf_dist_executor_fn factory.

        :param _: Necessary catch for the iterator given by Spark to the
        function upon foreach calls. Can safely be disregarded.
        """
        EnvSing.get_instance().set_ml_id(app_id, run_id)
        partition_id, _ = util.get_partition_attempt_id()

        log_file = log_dir + "/executor_" + str(partition_id) + ".log"

        reporter = Reporter(log_file, partition_id, 0, __builtin__.print)
        builtin_print = __builtin__.print

        def maggy_print(*args, **kwargs):
            builtin_print(*args, **kwargs)
            reporter.log(" ".join(str(x) for x in args), True)

        __builtin__.print = maggy_print

        try:
            tmp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            tmp_socket.bind(("", 0))

            tb_logdir, trial_log_file = _setup_logging(reporter, log_dir)
            tensorboard._register(tb_logdir)
            tf_config = {}  # dict, so the task entries below can be assigned

            physical_devices = tf.config.list_physical_devices("GPU")
            if physical_devices:  # an empty list means no GPU is available
                strategy = tf.distribute.MultiWorkerMirroredStrategy
                for count, pd in enumerate(physical_devices):
                    # PhysicalDevice names look like "/physical_device:GPU:0",
                    # so compare against the name attribute.
                    if pd.name.endswith("GPU:0"):
                        tf_config["task"] = {"type": "chief", "index": 0}
                    else:
                        tf_config["task"] = {"type": "worker", "index": count}
            else:  # Use the Default Strategy
                strategy = tf.distribute.get_strategy

            model = _wrap_model(config, strategy, False)

            if config.dataset is not None and config.process_data is not None:
                config.dataset = _consume_data(config)

            reporter.log(f"index of slice {partition_id}")
            reporter.log("Starting distributed training.")
            sig = inspect.signature(train_fn)

            kwargs = {}
            if sig.parameters.get("model", None):
                kwargs["model"] = model
            if sig.parameters.get("dataset", None):
                kwargs["dataset"] = config.dataset
            if sig.parameters.get("hparams", None):
                kwargs["hparams"] = config.hparams

            if sig.parameters.get("reporter", None):
                kwargs["reporter"] = reporter
                retval = train_fn(**kwargs)
            else:
                retval = train_fn(**kwargs)

            # Make retval compatible with util.handle_return_val: if the
            # function returned more than one metric as a dict, keep the
            # first value as the final metric.
            if isinstance(retval, dict):
                retval = list(retval.values())
            retval = {"Metric": retval[0] if isinstance(retval, list) else retval}
            retval = util.handle_return_val(retval, tb_logdir, "Metric", trial_log_file)
            reporter.log("Finished distributed training.")
        except:  # noqa: E722
            reporter.log(traceback.format_exc())
            raise
        finally:
            reporter.close_logger()
            __builtin__.print = builtin_print
        return retval
Example #8
def lagom(
    map_fun,
    name="no-name",
    experiment_type="optimization",
    searchspace=None,
    optimizer=None,
    direction="max",
    num_trials=1,
    ablation_study=None,
    ablator=None,
    optimization_key="metric",
    hb_interval=1,
    es_policy="median",
    es_interval=300,
    es_min=10,
    description="",
):
    """Launches a maggy experiment, which depending on `experiment_type` can
    either be a hyperparameter optimization or an ablation study experiment.
    Given a search space, objective and a model training procedure `map_fun`
    (black-box function), an experiment is the whole process of finding the
    best hyperparameter combination in the search space, optimizing the
    black-box function. Currently maggy supports random search and a median
    stopping rule.

    **lagom** is a Swedish word meaning "just the right amount".

    :param map_fun: User defined experiment containing the model training.
    :type map_fun: function
    :param name: A user defined experiment identifier.
    :type name: str
    :param experiment_type: Type of Maggy experiment, either 'optimization'
        (default) or 'ablation'.
    :type experiment_type: str
    :param searchspace: A maggy Searchspace object from which samples are
        drawn.
    :type searchspace: Searchspace
    :param optimizer: The optimizer is the part generating new trials.
    :type optimizer: str, AbstractOptimizer
    :param direction: If set to 'max' the highest value returned will
        correspond to the best solution, if set to 'min' the opposite is true.
    :type direction: str
    :param num_trials: The number of trials to evaluate given the search
        space, each containing a different hyperparameter combination.
    :type num_trials: int
    :param ablation_study: Ablation study object. Can be None for optimization
        experiment type.
    :type ablation_study: AblationStudy
    :param ablator: Ablator to use for experiment type 'ablation'.
    :type ablator: str, AbstractAblator
    :param optimization_key: Name of the metric to be optimized
    :type optimization_key: str, optional
    :param hb_interval: The heartbeat interval in seconds from trial executor
        to experiment driver, defaults to 1
    :type hb_interval: int, optional
    :param es_policy: The earlystopping policy, defaults to 'median'
    :type es_policy: str, optional
    :param es_interval: Frequency interval in seconds to check currently
        running trials for early stopping, defaults to 300
    :type es_interval: int, optional
    :param es_min: Minimum number of trials finalized before checking for
        early stopping, defaults to 10
    :type es_min: int, optional
    :param description: A longer description of the experiment.
    :type description: str, optional
    :raises RuntimeError: An experiment is currently running.
    :return: A dictionary indicating the best trial and best hyperparameter
        combination with its performance metric
    :rtype: dict
    """
    global running
    if running:
        raise RuntimeError("An experiment is currently running.")

    job_start = time.time()
    sc = hopsutil._find_spark().sparkContext
    exp_driver = None

    try:
        global app_id
        global experiment_json
        global run_id
        app_id = str(sc.applicationId)

        app_id, run_id = util._validate_ml_id(app_id, run_id)

        # start run
        running = True
        experiment_utils._set_ml_id(app_id, run_id)

        # create experiment dir
        experiment_utils._create_experiment_dir(app_id, run_id)

        tensorboard._register(experiment_utils._get_logdir(app_id, run_id))

        num_executors = util.num_executors(sc)

        # start experiment driver
        if experiment_type == "optimization":

            assert num_trials > 0, "number of trials should be greater than zero"
            tensorboard._write_hparams_config(
                experiment_utils._get_logdir(app_id, run_id), searchspace
            )

            if num_executors > num_trials:
                num_executors = num_trials

            exp_driver = experimentdriver.ExperimentDriver(
                "optimization",
                searchspace=searchspace,
                optimizer=optimizer,
                direction=direction,
                num_trials=num_trials,
                name=name,
                num_executors=num_executors,
                hb_interval=hb_interval,
                es_policy=es_policy,
                es_interval=es_interval,
                es_min=es_min,
                description=description,
                log_dir=experiment_utils._get_logdir(app_id, run_id),
            )

            exp_function = exp_driver.optimizer.name()

        elif experiment_type == "ablation":
            exp_driver = experimentdriver.ExperimentDriver(
                "ablation",
                ablation_study=ablation_study,
                ablator=ablator,
                name=name,
                num_executors=num_executors,
                hb_interval=hb_interval,
                description=description,
                log_dir=experiment_utils._get_logdir(app_id, run_id),
            )
            # using exp_driver.num_executor since
            # it has been set using ablator.get_number_of_trials()
            # in experiment.py
            if num_executors > exp_driver.num_executors:
                num_executors = exp_driver.num_executors

            exp_function = exp_driver.ablator.name()
        else:
            running = False
            raise RuntimeError(
                "Unknown experiment_type: should be either 'optimization' "
                "or 'ablation', but it is '{0}'".format(str(experiment_type))
            )

        nodeRDD = sc.parallelize(range(num_executors), num_executors)

        # Do provenance after initializing exp_driver, because exp_driver does
        # the type checks for optimizer and searchspace
        sc.setJobGroup(os.environ["ML_ID"], "{0} | {1}".format(name, exp_function))

        experiment_json = experiment_utils._populate_experiment(
            name,
            exp_function,
            "MAGGY",
            exp_driver.searchspace.json(),
            description,
            app_id,
            direction,
            optimization_key,
        )

        experiment_json = experiment_utils._attach_experiment_xattr(
            app_id, run_id, experiment_json, "CREATE"
        )

        util._log(
            "Started Maggy Experiment: {0}, {1}, run {2}".format(name, app_id, run_id)
        )

        exp_driver.init(job_start)

        server_addr = exp_driver.server_addr

        # Force execution on executor, since GPU is located on executor
        nodeRDD.foreachPartition(
            trialexecutor._prepare_func(
                app_id,
                run_id,
                experiment_type,
                map_fun,
                server_addr,
                hb_interval,
                exp_driver._secret,
                optimization_key,
                experiment_utils._get_logdir(app_id, run_id),
            )
        )
        job_end = time.time()

        result = exp_driver.finalize(job_end)
        best_logdir = (
            experiment_utils._get_logdir(app_id, run_id) + "/" + result["best_id"]
        )

        util._finalize_experiment(
            experiment_json,
            float(result["best_val"]),
            app_id,
            run_id,
            "FINISHED",
            exp_driver.duration,
            experiment_utils._get_logdir(app_id, run_id),
            best_logdir,
            optimization_key,
        )

        util._log("Finished Experiment")

        return result

    except:  # noqa: E722
        _exception_handler(
            experiment_utils._seconds_to_milliseconds(time.time() - job_start)
        )
        if exp_driver:
            if exp_driver.exception:
                raise exp_driver.exception
        raise
    finally:
        # grace period to send last logs to sparkmagic
        # sparkmagic hb poll interval is 5 seconds, therefore wait 6 seconds
        time.sleep(6)
        # cleanup spark jobs
        if running and exp_driver is not None:
            exp_driver.stop()
        run_id += 1
        running = False
        sc.setJobGroup("", "")

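
A hedged end-to-end usage sketch, assuming lagom is exposed as maggy.experiment.lagom and that Searchspace follows maggy's documented (type, bounds) tuple convention; the training body and trial budget are placeholders:

from maggy import Searchspace, experiment

sp = Searchspace(kernel=("INTEGER", [2, 8]), dropout=("DOUBLE", [0.01, 0.99]))

def train(kernel, dropout, reporter):
    accuracy = 1.0 - dropout  # placeholder for real model training
    return accuracy

result = experiment.lagom(
    train,
    name="demo",
    searchspace=sp,
    optimizer="randomsearch",
    direction="max",
    num_trials=5,
)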
Example #9
    def _wrapper_fun(_: Any) -> None:
        """Patched function from trial_executor_fn factory.

        :param _: Necessary catch for the iterator given by Spark to the
        function upon foreach calls. Can safely be disregarded.
        """
        env = EnvSing.get_instance()

        env.set_ml_id(app_id, run_id)

        # get task context information to determine executor identifier
        partition_id, task_attempt = util.get_partition_attempt_id()

        client = env.get_client(
            server_addr,
            partition_id,
            hb_interval,
            secret,
            socket.socket(socket.AF_INET, socket.SOCK_STREAM),
        )
        log_file = (log_dir + "/executor_" + str(partition_id) + "_" +
                    str(task_attempt) + ".log")

        # save the builtin print
        original_print = __builtin__.print

        reporter = Reporter(log_file, partition_id, task_attempt,
                            original_print)

        def maggy_print(*args, **kwargs):
            """Maggy custom print() function."""
            original_print(*args, **kwargs)
            reporter.log(" ".join(str(x) for x in args), True)

        # override the builtin print
        __builtin__.print = maggy_print

        try:
            client_addr = client.client_addr

            host_port = client_addr[0] + ":" + str(client_addr[1])

            exec_spec = {}
            exec_spec["partition_id"] = partition_id
            exec_spec["task_attempt"] = task_attempt
            exec_spec["host_port"] = host_port
            exec_spec["trial_id"] = None

            reporter.log("Registering with experiment driver", False)
            client.register(exec_spec)

            client.start_heartbeat(reporter)

            # blocking
            trial_id, parameters = client.get_suggestion(reporter)
            while not client.done:
                if experiment_type == "ablation":
                    ablation_params = {
                        "ablated_feature":
                        parameters.get("ablated_feature", "None"),
                        "ablated_layer":
                        parameters.get("ablated_layer", "None"),
                    }
                    parameters.pop("ablated_feature")
                    parameters.pop("ablated_layer")

                tb_logdir = log_dir + "/" + trial_id
                trial_log_file = tb_logdir + "/output.log"
                reporter.set_trial_id(trial_id)

                # If trial is repeated, delete trial directory, except log file
                if env.exists(tb_logdir):
                    util.clean_dir(tb_logdir, [trial_log_file])
                else:
                    env.mkdir(tb_logdir)

                reporter.init_logger(trial_log_file)
                tensorboard._register(tb_logdir)
                if experiment_type == "ablation":
                    env.dump(
                        json.dumps(ablation_params,
                                   default=util.json_default_numpy),
                        tb_logdir + "/.hparams.json",
                    )

                else:
                    env.dump(
                        json.dumps(parameters,
                                   default=util.json_default_numpy),
                        tb_logdir + "/.hparams.json",
                    )

                model = config.model
                dataset = config.dataset

                try:
                    reporter.log("Starting Trial: {}".format(trial_id), False)
                    reporter.log("Trial Configuration: {}".format(parameters),
                                 False)

                    if experiment_type == "optimization":
                        tensorboard._write_hparams(parameters, trial_id)

                    sig = inspect.signature(train_fn)
                    kwargs = {}
                    if sig.parameters.get("model", None):
                        kwargs["model"] = model
                    if sig.parameters.get("dataset", None):
                        kwargs["dataset"] = dataset
                    if sig.parameters.get("hparams", None):
                        kwargs["hparams"] = parameters

                    if sig.parameters.get("reporter", None):
                        kwargs["reporter"] = reporter
                        retval = train_fn(**kwargs)
                    else:
                        retval = train_fn(**kwargs)

                    # todo: test this change
                    # If the function returned multiple metrics as a dict,
                    # keep the first value as the final metric.
                    if isinstance(retval, dict):
                        retval = list(retval.values())
                    retval = {
                        "Metric":
                        retval[0] if isinstance(retval, list) else retval
                    }
                    retval = util.handle_return_val(retval, tb_logdir,
                                                    optimization_key,
                                                    trial_log_file)

                except exceptions.EarlyStopException as e:
                    retval = e.metric
                    reporter.log("Early Stopped Trial.", False)

                reporter.log("Finished Trial: {}".format(trial_id), False)
                reporter.log("Final Metric: {}".format(retval), False)
                client.finalize_metric(retval, reporter)

                # blocking
                trial_id, parameters = client.get_suggestion(reporter)

        except:  # noqa: E722
            reporter.log(traceback.format_exc(), False)
            raise
        finally:
            reporter.close_logger()
            client.stop()
            client.close()
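
The return-value handling above accepts either a scalar or a dict; with a dict, only the first value (in insertion order) survives as the trial metric. A hedged sketch of a compatible train_fn, with placeholder results:

def train_fn(model, dataset, hparams, reporter):
    # Hypothetical user function; the keyword names mirror the signature
    # checks performed by the executor above.
    loss, accuracy = 0.1, 0.9  # placeholders for real training results
    reporter.log("trial finished", False)
    return {"loss": loss, "accuracy": accuracy}  # "loss" becomes the metric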