Beispiel #1
0
def create_cluster():
    """Join the cluster as this node's two tasks, start a TF server, then wait for "stop".

    Registers tasks ``<NODE_NAME>:0`` and ``<NODE_NAME>:1`` through
    ``cluster.start_cluster``, exports the resulting spec as the TF config,
    launches the TF server and finally calls ``event.wait(client, "stop")``
    (presumably blocking until a stop signal is published — confirm against
    the ``event`` module).
    """
    app_client = skein.ApplicationClient.from_current()
    node_tasks = [f'{NODE_NAME}:0', f'{NODE_NAME}:1']
    spec = cluster.start_cluster(app_client, node_tasks)

    # Configure TF from the spec before bringing up the server.
    cluster.setup_tf_config(spec)
    cluster.start_tf_server(spec)

    event.wait(app_client, "stop")
Beispiel #2
0
def test_start_cluster_worker(task_name, task_index):
    """start_cluster must publish this task's init event with the current host:port.

    The ``event`` module of the code under test is patched so that ``wait``
    answers ``<task>/init`` lookups from an in-memory table. ``os.environ`` is
    patched so the ``SKEIN_CONTAINER_ID`` override is undone afterwards.
    """
    task = f"{task_name}:{task_index}"
    current_address = f"{CURRENT_HOST}:{CURRENT_PORT}"

    # Insertion order matters: if task == "worker:0", the second entry wins.
    fake_kv = {
        "worker:0/init": [f"{WORKER0_HOST}:{WORKER0_PORT}"],
        f"{task}/init": [current_address],
    }

    with mock.patch.dict(os.environ), \
            mock.patch(f"{MODULE_TO_TEST}.event") as patched_event:
        os.environ["SKEIN_CONTAINER_ID"] = f"{task_name}_{task_index}"

        # Resolve each awaited key to the first address stored for it.
        patched_event.wait.side_effect = lambda client, key: fake_kv[key][0]
        client_stub = mock.Mock(spec=skein.ApplicationClient)

        cluster.start_cluster((CURRENT_HOST, CURRENT_PORT), client_stub, [task, "worker:0"])

        patched_event.init_event.assert_called_once_with(client_stub, task,
                                                         current_address)
Beispiel #3
0
def _prepare_container(
    host_port: Tuple[str, int]
) -> Tuple[skein.ApplicationClient, Dict[str, List[str]], List[str]]:
    """Set up container logging and join the cluster at ``host_port``.

    Per the original note, the socket for ``host_port`` is meant to stay open
    while this runs (presumably held by the caller to reserve the port —
    confirm at the call site).

    Returns:
        A ``(client, cluster_spec, cluster_tasks)`` tuple: the skein client for
        the current application, the spec returned by
        ``cluster.start_cluster`` and the task list it was started with.
    """
    app_client = skein.ApplicationClient.from_current()
    _setup_container_logs(app_client)

    tasks = _get_cluster_tasks(app_client)
    spec = cluster.start_cluster(host_port, app_client, tasks)

    return app_client, spec, tasks
Beispiel #4
0
def _prepare_container(
) -> Tuple[skein.ApplicationClient, Dict[str, List[str]], List[str]]:
    """Log runtime versions, set up container logs and start the cluster.

    The task list is taken from the JSON document read via
    ``client.kv.wait(KV_CLUSTER_INSTANCES)`` and expanded with ``iter_tasks``.

    Returns:
        A ``(client, cluster_spec, cluster_tasks)`` tuple: the skein client for
        the current application, the spec returned by
        ``cluster.start_cluster`` and the task list it was started with.
    """
    tf.logging.info("Python " + sys.version)
    tf.logging.info("Skein " + skein.__version__)
    tf.logging.info(f"TensorFlow {tf.GIT_VERSION} {tf.VERSION}")

    app_client = skein.ApplicationClient.from_current()
    _setup_container_logs(app_client)

    # The KV store returns bytes; decode, parse the JSON, then expand to tasks.
    raw_instances = app_client.kv.wait(KV_CLUSTER_INSTANCES).decode()
    tasks = list(iter_tasks(json.loads(raw_instances)))

    spec = cluster.start_cluster(app_client, tasks)
    return app_client, spec, tasks