コード例 #1
0
def main() -> None:
    """Run one task of a parameter-server-strategy training job.

    While a socket address is reserved, the container is prepared, TF_CONFIG
    is exported and the experiment is fetched from the skein client. After
    that, the TF server is started, the experiment is dispatched on a
    monitored thread, and the container is shut down. "ps" tasks do not wait
    for their thread to finish.

    Raises:
        ValueError: if the fetched experiment is a KerasExperiment (not
            supported with this strategy) or is not an Experiment at all.
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()
    task_label = f"{task_type}:{task_id}"

    with _internal.reserve_sock_addr() as reserved_addr:
        client, cluster_spec, cluster_tasks = \
            _task_commons._prepare_container(reserved_addr)
        # TF_CONFIG has to be in the environment before the estimator is
        # instantiated, otherwise training does not run distributed.
        cluster.setup_tf_config(cluster_spec)
        experiment = _task_commons._get_experiment(client)
        if not isinstance(experiment, Experiment):
            if isinstance(experiment, KerasExperiment):
                raise ValueError(
                    "KerasExperiment using parameter strategy is unsupported")
            raise ValueError(
                "experiment must be an Experiment or a KerasExperiment")
        session_config = experiment.config.session_config
        _logger.info(f"Starting server {task_label}")

    # NOTE(review): the server is started outside the reservation context,
    # presumably so the reserved port is free to bind -- confirm in
    # _internal.reserve_sock_addr.
    cluster.start_tf_server(cluster_spec, session_config)
    worker_thread = _task_commons._execute_dispatched_function(
        client, experiment)

    # "ps" tasks never terminate by themselves, so only other task types are
    # joined. See https://github.com/tensorflow/tensorflow/issues/4713.
    if task_type != 'ps':
        worker_thread.join()
        _logger.info(f"{task_label} {worker_thread.state}")

    _task_commons._shutdown_container(
        client, cluster_tasks, session_config, worker_thread)
コード例 #2
0
def _model_dir_from_experiment(client):
    """Fetch the experiment through ``client`` and return its estimator's model_dir."""
    _logger.info("Read model_dir from estimator config")
    return _task_commons._get_experiment(client).estimator.config.model_dir


def main() -> None:
    """Serve TensorBoard from this container until every cluster task stops.

    The model directory is read from the ``TB_MODEL_DIR`` environment
    variable. When that is unset or empty, it falls back to the estimator
    config of the experiment. TensorBoard runs on a daemon monitored thread.
    Once each cluster task's "<task>/stop" event has been observed, the thread
    is joined for at most the termination timeout. This task then emits its
    own stop event, carrying the thread's exception if any, and its stop time.
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()
    task = cluster.get_task()
    client = skein.ApplicationClient.from_current()

    _task_commons._setup_container_logs(client)
    cluster_tasks = _task_commons._get_cluster_tasks(client)

    # An empty TB_MODEL_DIR counts as unset.
    model_dir = (os.getenv('TB_MODEL_DIR', "")
                 or _model_dir_from_experiment(client))

    _logger.info(f"Starting tensorboard on {model_dir}")

    tb_thread = _internal.MonitoredThread(
        name=f"{task_type}:{task_id}",
        target=tensorboard.start_tf_board,
        args=(client, model_dir),
        daemon=True,
    )
    tb_thread.start()

    # Keep TensorBoard up until every other task has signalled its stop.
    for peer in cluster_tasks:
        event.wait(client, f"{peer}/stop")

    tb_thread.join(tensorboard.get_termination_timeout())

    event.stop_event(client, task, tb_thread.exception)
    event.broadcast_container_stop_time(client, task)
コード例 #3
0
def main() -> None:
    """Run a standard TensorFlow server in this container until "stop" is signalled.

    While a socket address is reserved, the container is prepared, TF_CONFIG
    is exported, and the TF session config is read from the application's
    key-value store (``constants.KV_TF_SESSION_CONFIG``). That config is then
    passed to the standard TensorFlow server so it actually takes effect.
    Previously it was only logged. The function returns once the "stop" event
    is received.
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()
    with _internal.reserve_sock_addr() as host_port:
        client, cluster_spec, cluster_tasks = _task_commons._prepare_container(host_port)
        cluster.setup_tf_config(cluster_spec)
        # NOTE(review): cloudpickle.loads can execute arbitrary code; this
        # assumes the KV entry is written only by trusted code of this
        # application -- confirm.
        tf_session_config = cloudpickle.loads(client.kv.wait(constants.KV_TF_SESSION_CONFIG))
        _logger.info(f"tf_server_conf {tf_session_config}")
        _logger.info(f"Starting server {task_type}:{task_id}")

    # Hand the fetched config to the server; it was previously fetched and
    # logged but never applied.
    tf.contrib.distribute.run_standard_tensorflow_server(
        session_config=tf_session_config)
    event.wait(client, "stop")
コード例 #4
0
ファイル: _task_commons.py プロジェクト: rom1504/tf-yarn
def _execute_dispatched_function(client: skein.ApplicationClient,
                                 experiment: Experiment) -> MonitoredThread:
    """Launch ``experiment`` on a daemon MonitoredThread and announce the start.

    The thread is named "<task_type>:<task_id>". Its target is the monitored
    train-and-evaluate function built for ``client``, and its arguments are
    the elements of ``experiment``. After the thread has started, the task's
    start event is emitted and the thread is returned to the caller.
    """
    task_type, task_id = cluster.get_task_description()
    label = f"{task_type}:{task_id}"
    _logger.info(f"Starting execution {label}")

    train_and_evaluate = _gen_monitored_train_and_evaluate(client)
    experiment_args = tuple(experiment)
    runner = MonitoredThread(name=label,
                             target=train_and_evaluate,
                             args=experiment_args,
                             daemon=True)
    runner.start()

    event.start_event(client, cluster.get_task())
    return runner
コード例 #5
0
def main():
    """Container entry point: run the evaluator, or do nothing for other task types.

    The task's event is initialised first and its stop event is emitted at
    the end. If the work raises, the stop event is emitted with that
    exception before the exception is re-raised. Previously no stop event
    was sent at all on failure.
    """
    client = skein.ApplicationClient.from_current()
    task = cluster.get_task()
    task_type, task_id = cluster.get_task_description()
    event.init_event(client, task, "127.0.0.1:0")
    _task_commons._setup_container_logs(client)

    try:
        if task_type == "evaluator":
            evaluator_fn(client)
        else:
            logger.info(f"{task_type}:{task_id}: nothing to do")
    except Exception as exc:
        # NOTE(review): assumes stop_event accepts an exception as its third
        # argument (the tensorboard task passes MonitoredThread.exception
        # there). Reporting it presumably unblocks tasks waiting on this
        # task's stop event; the original error is then propagated unchanged.
        event.stop_event(client, task, exc)
        raise

    event.stop_event(client, task, None)
コード例 #6
0
def main():
    """Container entry point: dispatch to the worker or evaluator logic by task type.

    "chief" and "worker" tasks run ``_worker_fn`` and "evaluator" runs
    ``_evaluator_fn``. Any other type is logged as an error. The task's event
    is initialised first and its stop event is emitted at the end. If the work
    raises, the stop event carries that exception and the exception is then
    re-raised, instead of no stop event being sent.
    """
    client = skein.ApplicationClient.from_current()
    task_type, task_id = cluster.get_task_description()
    task = cluster.get_task()
    # Plain literal: the former f-string had no placeholders (same value).
    event.init_event(client, task, "127.0.0.1:0")
    _task_commons._setup_container_logs(client)

    try:
        if task_type in ['chief', 'worker']:
            _worker_fn(task_type, task_id, client)
        elif task_type == 'evaluator':
            _evaluator_fn(client)
        else:
            logger.error(f'Unknown task type {task_type}')
    except Exception as exc:
        # NOTE(review): assumes stop_event accepts an exception as its third
        # argument (the tensorboard task passes MonitoredThread.exception
        # there). Reporting it presumably unblocks tasks waiting on this
        # task's stop event; the original error is then propagated unchanged.
        event.stop_event(client, task, exc)
        raise

    event.stop_event(client, task, None)
コード例 #7
0
ファイル: _task_commons.py プロジェクト: jcuquemelle/tf-yarn
def _execute_dispatched_function(
        client: skein.ApplicationClient,
        experiment: Union[Experiment, KerasExperiment]) -> MonitoredThread:
    """Launch an estimator ``experiment`` on a daemon MonitoredThread.

    Only ``Experiment`` is supported here. The thread is named
    "<task_type>:<task_id>", its target is the monitored train-and-evaluate
    function built for ``client``, and its arguments are the elements of
    ``experiment``. After the thread has started, the task's start event is
    emitted and the thread is returned.

    Raises:
        ValueError: if ``experiment`` is a KerasExperiment (parameter
            strategy unsupported) or not an Experiment at all.
    """
    task_type, task_id = cluster.get_task_description()
    label = f"{task_type}:{task_id}"
    _logger.info(f"Starting execution {label}")

    # Same check order as an if/elif chain: Experiment is tested first.
    if not isinstance(experiment, Experiment):
        if isinstance(experiment, KerasExperiment):
            raise ValueError(
                "KerasExperiment using parameter strategy is unsupported")
        raise ValueError(
            "experiment must be an Experiment or a KerasExperiment")

    runner = MonitoredThread(
        name=label,
        target=_gen_monitored_train_and_evaluate(client),
        args=tuple(experiment),
        daemon=True)
    runner.start()
    event.start_event(client, cluster.get_task())
    return runner
コード例 #8
0
ファイル: test_cluster.py プロジェクト: rom1504/tf-yarn
def test_get_task_description():
    """SKEIN_CONTAINER_ID "MYTASK_42" must yield the description ("MYTASK", 42)."""
    with mock.patch.dict(os.environ):
        os.environ["SKEIN_CONTAINER_ID"] = "MYTASK_42"
        # The expected tuple must be parenthesised: the former
        # `assert "MYTASK", 42 == x` asserted the truthy string "MYTASK"
        # with message `42 == x`, so it could never fail.
        assert cluster.get_task_description() == ("MYTASK", 42)