Example #1
def main() -> None:
    _log_sys_info()
    task_type, task_id = get_task_description()

    client = skein.ApplicationClient.from_current()
    experiment = _get_experiment(client)
    assert isinstance(experiment, PytorchExperiment)
    cluster_tasks = _get_cluster_tasks(client)
    n_workers_per_executor = experiment.n_workers_per_executor

    world_size = (len([t for t in cluster_tasks if "worker" in t])
                  * n_workers_per_executor)
    _logger.info(f"Task type: {task_type}; Task id: {task_id}; "
                 f"World size: {world_size}; Cluster tasks: {cluster_tasks}")

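    # when running several workers per executor, spawn one training process per local worker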
    if n_workers_per_executor > 1:
        workers = list()
        mp.set_start_method("spawn", force=True)
        for n in range(n_workers_per_executor):
            worker = mp.Process(
                target=_train,
                args=(_get_device(n), (task_id * n_workers_per_executor) + n,
                      world_size,
                      _get_collective_ops_backend(n_workers_per_executor)))
            worker.start()
            workers.append(worker)

        for worker in workers:
            worker.join()
    else:
        _train(0, task_id, world_size, "nccl")
Example #2
def main():
    client = skein.ApplicationClient.from_current()
    task = get_task()
    task_type, task_id = get_task_description()
    event.init_event(client, task, "127.0.0.1:0")
    _task_commons._setup_container_logs(client)

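    # only the evaluator has dedicated work in this entry point; other task types just log and finish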
    if task_type == "evaluator":
        evaluator_fn(client)
    else:
        logger.info(f"{task_type}:{task_id}: nothing to do")

    event.stop_event(client, task, None)
Example #3
def start_tf_server(
    cluster_spec: typing.Dict[str, typing.List[str]],
    session_config: typing.Optional[tf.compat.v1.ConfigProto] = None
) -> typing.Optional[tf.distribute.Server]:

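    # a tf.distribute.Server is only started for "fake google env" task types with a non-empty cluster spec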
    task_type, task_id = get_task_description()
    if _is_fake_google_env(task_type) and cluster_spec:
        server = tf.distribute.Server(tf.train.ClusterSpec(cluster_spec),
                                      job_name=task_type,
                                      task_index=task_id,
                                      config=session_config,
                                      start=True)
        return server
    return None
Example #4
def setup_tf_config(cluster_spec):
    # Note that "evaluator" does not need a cluster, and "ps" (!)
    # surprisingly does not follow the same code path as the rest
    # and spawns a server regardless of the "environment" value.
    task_type, task_id = get_task_description()
    _internal.xset_environ(TF_CONFIG=json.dumps(
        {
            "cluster": cluster_spec,
            "environment": "google" if _is_fake_google_env(task_type) else "",
            "task": {
                "type": task_type,
                "index": task_id
            },
        }))
Example #5
def _tensorboard(
        tensorboard_dir: str,
        client: skein.ApplicationClient) -> Generator[None, None, None]:
    task_type, task_id = get_task_description()

    thread = _internal.MonitoredThread(name=f"{task_type}:{task_id}",
                                       target=tensorboard.start_tf_board,
                                       args=(client, tensorboard_dir),
                                       daemon=True)
    thread.start()

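    # hand control back to the caller while tensorboard keeps running in the daemon thread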
    yield

    timeout = tensorboard.get_termination_timeout()
    thread.join(timeout)
Example #6
def main():
    client = skein.ApplicationClient.from_current()
    task_type, task_id = get_task_description()
    task = get_task()
    event.init_event(client, task, "127.0.0.1:0")
    _task_commons._setup_container_logs(client)
    net_if = get_net_if()

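    # the chief first runs the driver function and then also acts as a worker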
    if task_type == 'chief':
        _driver_fn(client, net_if)
    if task_type in ['worker', 'chief']:
        _worker_fn(client, task, net_if)
    elif task_type == 'evaluator':
        evaluator_fn(client)
    else:
        logger.error(f'Unknown task type {task_type}')

    event.stop_event(client, task, None)
Example #7
def _execute_dispatched_function(
        client: skein.ApplicationClient,
        experiment: Union[Experiment, KerasExperiment]) -> MonitoredThread:
    task_type, task_id = get_task_description()
    _logger.info(f"Starting execution {task_type}:{task_id}")
    if isinstance(experiment, Experiment):
        thread = MonitoredThread(
            name=f"{task_type}:{task_id}",
            target=_gen_monitored_train_and_evaluate(client),
            args=tuple(experiment),
            daemon=True)
    elif isinstance(experiment, KerasExperiment):
        raise ValueError(
            "KerasExperiment using parameter strategy is unsupported")
    else:
        raise ValueError(
            "experiment must be an Experiment or a KerasExperiment")
    thread.start()
    task = get_task()
    event.start_event(client, task)
    return thread
Example #8
def main() -> None:
    _log_sys_info()
    task_type, task_id = get_task_description()
    task = get_task()
    client = skein.ApplicationClient.from_current()

    _setup_container_logs(client)
    cluster_tasks = _get_cluster_tasks(client)

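    # prefer an explicit TB_MODEL_DIR; otherwise read the model directory from the experiment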
    model_dir = os.getenv('TB_MODEL_DIR', "")
    if not model_dir:
        _logger.info("Read model_dir from estimator config")
        experiment = _get_experiment(client)
        if isinstance(experiment, Experiment):
            model_dir = experiment.estimator.config.model_dir
        elif isinstance(experiment, KerasExperiment):
            model_dir = experiment.model_dir
        else:
            raise ValueError("experiment must be an Experiment or a KerasExperiment")

    _logger.info(f"Starting tensorboard on {model_dir}")

    thread = _internal.MonitoredThread(
        name=f"{task_type}:{task_id}",
        target=tensorboard.start_tf_board,
        args=(client, model_dir),
        daemon=True)
    thread.start()

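    # keep tensorboard alive until every cluster task has signalled its stop event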
    for cluster_task in cluster_tasks:
        event.wait(client, f"{cluster_task}/stop")

    timeout = tensorboard.get_termination_timeout()
    thread.join(timeout)

    event.stop_event(client, task, thread.exception)
    event.broadcast_container_stop_time(client, task)
Example #9
def test_get_task_description():
    with mock.patch.dict(os.environ):
        os.environ["SKEIN_CONTAINER_ID"] = "MYTASK_42"
        assert "MYTASK", 42 == get_task_description()