def main() -> None:
    """Run one task of a distributed TensorFlow training job.

    Reserves a socket address for this container, publishes the cluster
    spec, starts the local ``tf.distribute`` server and dispatches the
    user experiment function. Raises ``ValueError`` when the experiment
    returned by the client is not a supported type.
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()

    # Hold the reserved port until the TF server is about to bind it.
    with _internal.reserve_sock_addr() as host_port:
        client, cluster_spec, cluster_tasks = _task_commons._prepare_container(
            host_port)
        # Variable TF_CONFIG must be set before instantiating
        # the estimator to train in a distributed way
        cluster.setup_tf_config(cluster_spec)
        experiment = _task_commons._get_experiment(client)
        if isinstance(experiment, Experiment):
            session_config = experiment.config.session_config
        elif isinstance(experiment, KerasExperiment):
            raise ValueError(
                "KerasExperiment using parameter strategy is unsupported")
        else:
            raise ValueError(
                "experiment must be an Experiment or a KerasExperiment")
        _logger.info(f"Starting server {task_type}:{task_id}")

    cluster.start_tf_server(cluster_spec, session_config)
    thread = _task_commons._execute_dispatched_function(client, experiment)

    # "ps" tasks do not terminate by themselves. See
    # https://github.com/tensorflow/tensorflow/issues/4713.
    if task_type not in ['ps']:
        thread.join()
        _logger.info(f"{task_type}:{task_id} {thread.state}")

    _task_commons._shutdown_container(client, cluster_tasks, session_config,
                                      thread)
def test_start_tf_server(task_name, task_index, is_server_started):
    """Check that a TF server is started only for tasks present in the spec.

    Patches ``tf.distribute`` and the process environment, then verifies
    the ``Server`` constructor call (job name, task index, eager start)
    when a server is expected, or that it was never invoked otherwise.
    """
    spec = {
        "worker": [
            f"worker0.{WORKER0_HOST}:{WORKER0_PORT}",
            f"worker1.{WORKER1_HOST}:{WORKER1_PORT}"
        ],
        "ps": [f"ps0.{CURRENT_HOST}:{CURRENT_PORT}"]
    }
    with mock.patch.dict(os.environ), \
            mock.patch(f"{MODULE_TO_TEST}.tf.distribute") as distribute_mock:
        # The container id encodes which task this process plays.
        os.environ["SKEIN_CONTAINER_ID"] = f"{task_name}_{task_index}"
        cluster.start_tf_server(spec)

        if not is_server_started:
            assert distribute_mock.Server.call_count == 0
        else:
            assert distribute_mock.Server.call_count == 1
            _, kwargs = distribute_mock.Server.call_args
            assert kwargs["job_name"] == task_name
            assert kwargs["task_index"] == task_index
            assert kwargs["start"] is True
def create_cluster():
    """Bootstrap a two-task TF cluster on the current skein application.

    Starts the cluster on two slots of ``NODE_NAME``, exports TF_CONFIG,
    launches the local TF server, then blocks until a "stop" event is
    published on the application client.
    """
    app_client = skein.ApplicationClient.from_current()
    task_slots = [f'{NODE_NAME}:0', f'{NODE_NAME}:1']
    spec = cluster.start_cluster(app_client, task_slots)
    cluster.setup_tf_config(spec)
    cluster.start_tf_server(spec)
    event.wait(app_client, "stop")