Example #1
def main() -> None:
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()
    with _internal.reserve_sock_addr() as host_port:
        client, cluster_spec, cluster_tasks = _task_commons._prepare_container(
            host_port)
        # Variable TF_CONFIG must be set before instantiating
        # the estimator to train in a distributed way
        cluster.setup_tf_config(cluster_spec)
        experiment = _task_commons._get_experiment(client)
        if isinstance(experiment, Experiment):
            session_config = experiment.config.session_config
        elif isinstance(experiment, KerasExperiment):
            raise ValueError(
                "KerasExperiment with parameter server strategy is not supported")
        else:
            raise ValueError(
                "experiment must be an Experiment or a KerasExperiment")
        _logger.info(f"Starting server {task_type}:{task_id}")

    cluster.start_tf_server(cluster_spec, session_config)
    thread = _task_commons._execute_dispatched_function(client, experiment)

    # "ps" tasks do not terminate by themselves. See
    # https://github.com/tensorflow/tensorflow/issues/4713.
    if task_type not in ['ps']:
        thread.join()
        _logger.info(f"{task_type}:{task_id} {thread.state}")

    _task_commons._shutdown_container(client, cluster_tasks, session_config,
                                      thread)
Example #2
def test_reserve_sock_addr():
    with reserve_sock_addr() as (host, port):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with pytest.raises(OSError) as exc_info:
            sock.bind((host, port))

        # Ensure that the context manager holds the socket open.
        assert exc_info.value.errno == errno.EADDRINUSE
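The assertion holds because a plain socket cannot bind the reserved port, while one that also sets SO_REUSEPORT can. A small illustration of that mechanism, under the same SO_REUSEPORT assumption as the sketch above:

import socket

with reserve_sock_addr() as (host, port):
    taker = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    taker.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    taker.bind(("", port))  # succeeds: both sockets opted into SO_REUSEPORT
    taker.close()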
Example #3
def main() -> None:
    task_type, task_id = cluster.get_task_description()
    with reserve_sock_addr() as host_port:
        client, cluster_spec, cluster_tasks = _prepare_container(host_port)
        cluster.setup_tf_config(cluster_spec)
        tf_session_config = cloudpickle.loads(client.kv.wait(KV_TF_SESSION_CONFIG))
        tf.logging.info(f"tf_server_conf {tf_session_config}")

    tf.contrib.distribute.run_standard_tensorflow_server()
    event.wait(client, "stop")
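The cloudpickle.loads(client.kv.wait(KV_TF_SESSION_CONFIG)) call implies that some producer publishes a pickled session config through skein's key-value store. A hypothetical producer side (the key name and config values are assumptions, not tf-yarn's actual code) might be:

import cloudpickle
import skein
import tensorflow as tf

KV_TF_SESSION_CONFIG = "tf_session_config"  # key name assumed

def publish_session_config(client: skein.ApplicationClient) -> None:
    # skein's kv store holds bytes, hence the pickling round-trip.
    session_config = tf.ConfigProto(allow_soft_placement=True)
    client.kv[KV_TF_SESSION_CONFIG] = cloudpickle.dumps(session_config)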
Example #4
def _setup_master(client: skein.ApplicationClient, rank: int) -> None:
    if rank == 0:
        with _internal.reserve_sock_addr() as host_port:
            event.broadcast(client, MASTER_ADDR, host_port[0])
            event.broadcast(client, MASTER_PORT, str(host_port[1]))
            os.environ[MASTER_ADDR] = host_port[0]
            os.environ[MASTER_PORT] = str(host_port[1])
    else:
        master_addr = event.wait(client, MASTER_ADDR)
        master_port = event.wait(client, MASTER_PORT)
        os.environ[MASTER_ADDR] = master_addr
        os.environ[MASTER_PORT] = master_port
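MASTER_ADDR and MASTER_PORT are the environment variables PyTorch's env:// rendezvous reads, so a worker would typically call _setup_master and then initialize the process group. A sketch under that assumption (the world_size plumbing is hypothetical):

import skein
import torch.distributed as dist

def _setup_distributed(client: skein.ApplicationClient,
                       rank: int, world_size: int) -> None:
    _setup_master(client, rank)
    # init_method defaults to "env://", which picks up the MASTER_ADDR
    # and MASTER_PORT values set by _setup_master above.
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)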
Example #5
def main() -> None:
    task_type, task_id = cluster.get_task_description()
    with reserve_sock_addr() as host_port:
        client, cluster_spec, cluster_tasks = _prepare_container(host_port)
        # Variable TF_CONFIG must be set before instantiating
        # the estimator to train in a distributed way
        cluster.setup_tf_config(cluster_spec)
        experiment = _get_experiment(client)
        run_config = experiment.config
        tf.logging.info(f"Starting server {task_type}:{task_id}")

    cluster.start_tf_server(cluster_spec, run_config.session_config)
    thread = _execute_dispatched_function(client, experiment)

    # "ps" tasks do not terminate by themselves. See
    # https://github.com/tensorflow/tensorflow/issues/4713.
    # Tensorboard is terminated after all other tasks in _shutdown_container
    if task_type not in ['ps', 'tensorboard']:
        thread.join()
        tf.logging.info(f"{task_type}:{task_id} {thread.state}")

    _shutdown_container(client, cluster_tasks, run_config, thread)
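cluster.setup_tf_config presumably serializes the cluster spec into the standard TF_CONFIG environment variable that distributed estimators read; for a parameter-server cluster it would have roughly this shape (hosts and task assignment below are illustrative):

import json
import os

tf_config = {
    "cluster": {
        "chief": ["worker-0.example.com:2222"],
        "worker": ["worker-1.example.com:2222"],
        "ps": ["ps-0.example.com:2222"],
    },
    "task": {"type": "worker", "index": 0},
}
os.environ["TF_CONFIG"] = json.dumps(tf_config)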
Example #6
def start_tf_board(client: skein.ApplicationClient, tf_board_model_dir: str):
    task = cluster.get_task()
    os.environ['GCS_READ_CACHE_DISABLED'] = '1'
    os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
    os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'
    try:
        program.setup_environment()
        tensorboard = program.TensorBoard()
        with _internal.reserve_sock_addr() as (h, p):
            tensorboard_url = f"http://{h}:{p}"
            argv = ['tensorboard', f"--logdir={tf_board_model_dir}",
                    f"--port={p}"]
            tb_extra_args = os.getenv('TB_EXTRA_ARGS', "")
            if tb_extra_args:
                argv += tb_extra_args.split(' ')
            tensorboard.configure(argv)
        tensorboard.launch()
        event.start_event(client, task)
        event.url_event(client, task, tensorboard_url)
    except Exception as e:
        _logger.exception("Cannot start tensorboard")
        event.stop_event(client, task, e)
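A hypothetical call site, showing how extra TensorBoard flags travel through the TB_EXTRA_ARGS variable that the function splits on spaces (the model dir is illustrative):

import os

os.environ["TB_EXTRA_ARGS"] = "--reload_interval 10"
start_tf_board(client, "hdfs:///user/alice/model")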