コード例 #1
0
def _shutdown_container(
    client: skein.ApplicationClient,
    cluster_tasks: Optional[List[str]],
    run_config: tf.estimator.RunConfig,
    thread: Optional[MonitoredThread]
) -> None:
    """Publish this task's stop event, wait for connected tasks, then exit.

    The stop event of the current task is sent first (carrying the exception
    recorded by ``thread``, if any). The function then waits for the tasks
    connected to this one, joins the TensorBoard thread when the current task
    is a TensorBoard task, and broadcasts the container stop time. Finally,
    the thread's exception, if any, is re-raised.

    :param client: skein application client used for events and waiting.
    :param cluster_tasks: names of the tasks in the cluster, or ``None``.
        Entries starting with ``'tensorboard'`` are excluded from the wait;
        ``None`` is forwarded unchanged to ``wait_for_connected_tasks``.
    :param run_config: estimator run config; the ``device_filters`` of its
        ``session_config`` (empty list when absent) restrict the wait.
    :param thread: the monitored thread that ran the task, or ``None``.
    :raises Exception: the exception recorded by ``thread``, with its
        context suppressed.
    """
    # ``isinstance`` is False for ``None``, so this also covers a missing thread.
    exception = thread.exception if isinstance(thread, MonitoredThread) else None
    task = cluster.get_task()
    event.stop_event(client, task, exception)

    # Wait for all tasks connected to this one. The set of tasks to
    # wait for contains all tasks in the cluster, or the ones
    # matching ``device_filters`` if set. The implementation assumes
    # that ``device_filters`` are symmetric. TensorBoard tasks are
    # never waited for.
    if cluster_tasks is None:
        tasks = None
    else:
        tasks = [c for c in cluster_tasks if not c.startswith('tensorboard')]
    wait_for_connected_tasks(
        client,
        tasks,
        getattr(run_config.session_config, "device_filters", []))

    if task.startswith('tensorboard'):
        # Join the TensorBoard thread, waiting at most the termination timeout.
        # NOTE(review): ``exception`` was read above, before this join, so an
        # error raised by the thread while it is being joined is not
        # reported — confirm this is intended.
        timeout = get_termination_timeout()
        if thread is not None:
            thread.join(timeout)
        tf.logging.info(f"{task} finished")
    event.broadcast_container_stop_time(client, task)

    if exception is not None:
        raise exception from None
コード例 #2
0
def main() -> None:
    """Run TensorBoard in this container until the rest of the cluster stops.

    The model directory is taken from the ``TB_MODEL_DIR`` environment
    variable, or from the experiment's estimator config when that variable is
    unset or empty. TensorBoard is started in a daemon ``MonitoredThread``.
    The function then waits for the ``/stop`` event of every non-TensorBoard
    cluster task and joins the thread with the termination timeout. Finally,
    it publishes this task's stop event (with the thread's exception, if any)
    and broadcasts the container stop time.
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()
    task = cluster.get_task()
    client = skein.ApplicationClient.from_current()

    _task_commons._setup_container_logs(client)
    cluster_tasks = _task_commons._get_cluster_tasks(client)

    # ``TB_MODEL_DIR`` takes precedence; otherwise fall back to the model
    # directory configured on the experiment's estimator.
    model_dir = os.getenv('TB_MODEL_DIR', "")
    if not model_dir:
        _logger.info("Read model_dir from estimator config")
        experiment = _task_commons._get_experiment(client)
        model_dir = experiment.estimator.config.model_dir

    _logger.info(f"Starting tensorboard on {model_dir}")

    thread = _internal.MonitoredThread(name=f"{task_type}:{task_id}",
                                       target=tensorboard.start_tf_board,
                                       args=(client, model_dir),
                                       daemon=True)
    thread.start()

    # Wait for every non-TensorBoard task to publish its stop event.
    # TensorBoard tasks (including this one) are skipped: this task only
    # publishes its own stop event after this loop, so waiting on it (or on a
    # sibling TensorBoard task doing the same) would block on an event that is
    # set only later. This matches the filter used in ``_shutdown_container``.
    for cluster_task in cluster_tasks:
        if cluster_task.startswith('tensorboard'):
            continue
        event.wait(client, f"{cluster_task}/stop")

    timeout = tensorboard.get_termination_timeout()
    thread.join(timeout)

    event.stop_event(client, task, thread.exception)
    event.broadcast_container_stop_time(client, task)
コード例 #3
0
def _shutdown_container(client: skein.ApplicationClient,
                        cluster_tasks: List[str],
                        session_config: tf.compat.v1.ConfigProto,
                        thread: Optional[MonitoredThread]) -> None:
    """Signal this task's termination and wait for its peers before exiting.

    The task's stop event is published first, carrying the exception recorded
    by ``thread`` (if any). The function then blocks until the tasks connected
    to this one are done and broadcasts the container stop time. Finally, the
    thread's exception, if there was one, is re-raised.

    :param client: skein application client used to publish events.
    :param cluster_tasks: task names of the cluster, forwarded to
        ``_wait_for_connected_tasks``.
    :param session_config: session config whose ``device_filters`` (empty
        list when absent) restrict the set of tasks to wait for.
    :param thread: the monitored thread that ran the task, or ``None``.
    :raises Exception: the exception recorded by ``thread``, with its
        context suppressed.
    """
    failure = None
    # ``isinstance`` rejects ``None`` as well as non-monitored threads.
    if isinstance(thread, MonitoredThread):
        failure = thread.exception

    current_task = get_task()
    event.stop_event(client, current_task, failure)

    # Wait for all tasks connected to this one: every task in the cluster,
    # or only those matching ``device_filters`` when set. This assumes
    # that ``device_filters`` are symmetric.
    filters = getattr(session_config, "device_filters", [])
    _wait_for_connected_tasks(client, cluster_tasks, filters)

    event.broadcast_container_stop_time(client, current_task)

    if failure is None:
        return
    raise failure from None