Пример #1
0
def main() -> None:
    """Entry point of a TensorFlow task running inside a container.

    Reserves a socket address, prepares the container, exports TF_CONFIG,
    starts the TensorFlow server, runs the dispatched function and finally
    shuts the container down.

    Raises:
        ValueError: if the retrieved experiment is not an ``Experiment``
            (a ``KerasExperiment`` is explicitly rejected here).
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()

    with _internal.reserve_sock_addr() as host_port:
        (client,
         cluster_spec,
         cluster_tasks) = _task_commons._prepare_container(host_port)

        # TF_CONFIG has to be exported before the estimator is instantiated,
        # otherwise the training would not be distributed.
        cluster.setup_tf_config(cluster_spec)

        experiment = _task_commons._get_experiment(client)
        if not isinstance(experiment, Experiment):
            # Tell apart the "known but unsupported" case from anything else.
            if isinstance(experiment, KerasExperiment):
                raise ValueError(
                    "KerasExperiment using parameter strategy is unsupported")
            raise ValueError(
                "experiment must be an Experiment or a KerasExperiment")
        session_config = experiment.config.session_config

        _logger.info(f"Starting server {task_type}:{task_id}")

    # The reserved socket address is released at this point, right before
    # the TensorFlow server is started.
    cluster.start_tf_server(cluster_spec, session_config)
    thread = _task_commons._execute_dispatched_function(client, experiment)

    # "ps" tasks do not terminate by themselves. See
    # https://github.com/tensorflow/tensorflow/issues/4713.
    is_parameter_server = task_type == 'ps'
    if not is_parameter_server:
        thread.join()
        _logger.info(f"{task_type}:{task_id} {thread.state}")

    _task_commons._shutdown_container(
        client, cluster_tasks, session_config, thread)
Пример #2
0
def test_start_tf_server(task_name, task_index, is_server_started):
    """Check when ``cluster.start_tf_server`` instantiates a TF server.

    Args:
        task_name: task type written into ``SKEIN_CONTAINER_ID``.
        task_index: task index written into ``SKEIN_CONTAINER_ID``.
        is_server_started: whether a ``Server`` is expected to be created
            for this task.
    """
    workers = [
        f"worker0.{WORKER0_HOST}:{WORKER0_PORT}",
        f"worker1.{WORKER1_HOST}:{WORKER1_PORT}",
    ]
    parameter_servers = [f"ps0.{CURRENT_HOST}:{CURRENT_PORT}"]
    cluster_spec = {"worker": workers, "ps": parameter_servers}

    # The environment is patched first (and restored on exit), then
    # tf.distribute is replaced by a mock so that no real server is started.
    container_env = {"SKEIN_CONTAINER_ID": f"{task_name}_{task_index}"}
    with mock.patch.dict(os.environ, container_env), \
            mock.patch(f"{MODULE_TO_TEST}.tf.distribute") as mock_distribute:
        cluster.start_tf_server(cluster_spec)

        server_cls = mock_distribute.Server
        if not is_server_started:
            assert server_cls.call_count == 0
            return

        assert server_cls.call_count == 1
        _, call_kwargs = server_cls.call_args
        assert call_kwargs["job_name"] == task_name
        assert call_kwargs["task_index"] == task_index
        assert call_kwargs["start"] is True
Пример #3
0
def create_cluster():
    """Start a two-instance cluster on ``NODE_NAME`` and its TF server.

    Builds the cluster spec through ``cluster.start_cluster``, exports
    TF_CONFIG, starts the TensorFlow server and then waits on the "stop"
    event of the current skein application.
    """
    skein_client = skein.ApplicationClient.from_current()
    node_instances = [f'{NODE_NAME}:{instance}' for instance in range(2)]
    spec = cluster.start_cluster(skein_client, node_instances)

    cluster.setup_tf_config(spec)
    cluster.start_tf_server(spec)

    # Presumably blocks until a "stop" event is published - confirm against
    # the implementation of event.wait.
    event.wait(skein_client, "stop")