def create_cluster():
    """Bring up a two-task TF cluster on this node and block until stopped.

    Connects to the running skein application, starts a cluster made of
    the two tasks ``{NODE_NAME}:0`` and ``{NODE_NAME}:1``, writes the TF
    configuration, launches the TF server, then waits for the "stop" event.
    """
    app_client = skein.ApplicationClient.from_current()
    spec = cluster.start_cluster(
        app_client, [f'{NODE_NAME}:0', f'{NODE_NAME}:1'])
    cluster.setup_tf_config(spec)
    cluster.start_tf_server(spec)
    event.wait(app_client, "stop")
def test_start_cluster_worker(task_name, task_index):
    """start_cluster must register this worker's init event with its address."""
    task = f"{task_name}:{task_index}"
    # Addresses the mocked event layer hands back for each task's init key.
    init_events = {
        "worker:0/init": [f"{WORKER0_HOST}:{WORKER0_PORT}"],
        f"{task}/init": [f"{CURRENT_HOST}:{CURRENT_PORT}"],
    }
    with mock.patch.dict(os.environ), \
            mock.patch(f"{MODULE_TO_TEST}.event") as fake_event:
        os.environ["SKEIN_CONTAINER_ID"] = f"{task_name}_{task_index}"
        fake_event.wait.side_effect = lambda client, key: init_events[key][0]
        fake_client = mock.Mock(spec=skein.ApplicationClient)
        cluster.start_cluster(
            (CURRENT_HOST, CURRENT_PORT), fake_client, [task, "worker:0"])
        fake_event.init_event.assert_called_once_with(
            fake_client, task, f"{CURRENT_HOST}:{CURRENT_PORT}")
def _prepare_container(
    host_port: Tuple[str, int]
) -> Tuple[skein.ApplicationClient, Dict[str, List[str]], List[str]]:
    """Prepare this container while keeping its reserved socket open.

    Connects to the current skein application, redirects container logs,
    fetches the cluster task list, and starts the cluster using the
    already-bound ``host_port`` address.

    Returns the skein client, the resulting cluster spec, and the task list.
    """
    skein_client = skein.ApplicationClient.from_current()
    _setup_container_logs(skein_client)
    tasks = _get_cluster_tasks(skein_client)
    spec = cluster.start_cluster(host_port, skein_client, tasks)
    return skein_client, spec, tasks
def _prepare_container(
) -> Tuple[skein.ApplicationClient, Dict[str, List[str]], List[str]]:
    """Prepare this container and start the TF cluster it belongs to.

    Logs the Python / Skein / TensorFlow versions, connects to the current
    skein application, redirects container logs, waits for the task list
    published under ``KV_CLUSTER_INSTANCES`` in the key-value store, and
    starts the cluster from it.

    Returns the skein client, the resulting cluster spec, and the task list.
    """
    # Use f-strings consistently for all three version lines (the original
    # mixed "+" concatenation with an f-string); the logged text is unchanged.
    tf.logging.info(f"Python {sys.version}")
    tf.logging.info(f"Skein {skein.__version__}")
    tf.logging.info(f"TensorFlow {tf.GIT_VERSION} {tf.VERSION}")
    client = skein.ApplicationClient.from_current()
    _setup_container_logs(client)
    # The task list is published as JSON bytes; decode, parse, then expand.
    cluster_tasks = list(
        iter_tasks(json.loads(client.kv.wait(KV_CLUSTER_INSTANCES).decode())))
    cluster_spec = cluster.start_cluster(client, cluster_tasks)
    return client, cluster_spec, cluster_tasks