Ejemplo n.º 1
0
def _shutdown_container(
    client: skein.ApplicationClient,
    cluster_tasks: List[str],
    run_config: tf.estimator.RunConfig,
    thread: Optional[MonitoredThread]
) -> None:
    """Publish this task's outcome, wait for its peers and close the container.

    Args:
        client: skein application client used to emit the events.
        cluster_tasks: names of the tasks in the cluster, or ``None``.
        run_config: estimator run configuration; its ``session_config`` is
            inspected for ``device_filters``.
        thread: thread that ran the task, if any.

    Raises:
        The exception recorded on ``thread`` (when it is a
        ``MonitoredThread`` holding one), re-raised with its context
        suppressed once all shutdown events have been sent.
    """
    # isinstance() is False for None, so a missing thread yields no exception.
    exception = None
    if isinstance(thread, MonitoredThread):
        exception = thread.exception

    task = cluster.get_task()
    event.stop_event(client, task, exception)

    # Wait for all tasks connected to this one. The set of tasks to
    # wait for contains all tasks in the cluster (tensorboard tasks
    # excluded), or the ones matching ``device_filters`` if set. The
    # implementation assumes that ``device_filters`` are symmetric.
    tasks = cluster_tasks
    if tasks is not None:
        tasks = [name for name in tasks if not name.startswith('tensorboard')]
    device_filters = getattr(run_config.session_config, "device_filters", [])
    wait_for_connected_tasks(client, tasks, device_filters)

    is_tensorboard = task.startswith('tensorboard')
    if is_tensorboard:
        # A tensorboard task gets a bounded grace period before finishing.
        timeout = get_termination_timeout()
        if thread is not None:
            thread.join(timeout)
        tf.logging.info(f"{task} finished")
    event.broadcast_container_stop_time(client, task)

    if exception is None:
        return
    raise exception from None
Ejemplo n.º 2
0
def main() -> None:
    """Entry point of the tensorboard task.

    Starts tensorboard in a monitored daemon thread, waits for the
    ``<task>/stop`` event of every cluster task, gives the thread a bounded
    time to terminate and finally reports this task's own stop events.
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()
    task = cluster.get_task()
    client = skein.ApplicationClient.from_current()

    _task_commons._setup_container_logs(client)
    cluster_tasks = _task_commons._get_cluster_tasks(client)

    # The environment variable wins; the estimator config is the fallback
    # when it is unset or empty.
    model_dir = os.environ.get('TB_MODEL_DIR', "")
    if not model_dir:
        _logger.info("Read model_dir from estimator config")
        model_dir = _task_commons._get_experiment(
            client).estimator.config.model_dir

    _logger.info(f"Starting tensorboard on {model_dir}")

    thread_name = f"{task_type}:{task_id}"
    thread = _internal.MonitoredThread(
        name=thread_name,
        target=tensorboard.start_tf_board,
        args=(client, model_dir),
        daemon=True,
    )
    thread.start()

    # Block until every task of the cluster has published its stop event.
    for peer in cluster_tasks:
        event.wait(client, f"{peer}/stop")

    thread.join(tensorboard.get_termination_timeout())

    event.stop_event(client, task, thread.exception)
    event.broadcast_container_stop_time(client, task)
Ejemplo n.º 3
0
 def __init__(self):
     """Bind the current skein client and task, and reset the counters.

     All step and evaluation accumulators start at zero; ``start_time``
     records the wall-clock time at construction.
     """
     self.client = skein.ApplicationClient.from_current()
     self.task = get_task()
     # NOTE(review): presumably updated as evaluation steps run -- confirm
     # against the methods of the enclosing class.
     self.step_counter = 0
     self.eval_start_time = self.eval_step_dur_accu = 0.0
     self.start_time = time.time()
Ejemplo n.º 4
0
def start_tf_board(client: skein.ApplicationClient,
                   experiment: Experiment = None):
    """Launch tensorboard and return the thread it runs in.

    The log directory is taken from ``experiment.estimator.config.model_dir``
    when an experiment is given, otherwise from the ``TF_BOARD_MODEL_DIR``
    environment variable. Extra command-line arguments can be supplied,
    whitespace separated, through ``TF_BOARD_EXTRA_ARGS``.

    Args:
        client: skein application client used to emit the events.
        experiment: optional experiment providing the model directory.

    Returns:
        The thread named ``'TensorBoard'`` once tensorboard is launched, or
        ``None`` if anything failed. Failures are not raised: they are
        reported through a stop event carrying the exception.
    """
    thread = None
    if experiment:
        model_dir = experiment.estimator.config.model_dir
    else:
        model_dir = os.environ.get('TF_BOARD_MODEL_DIR', None)
    task = cluster.get_task()
    os.environ['GCS_READ_CACHE_DISABLED'] = '1'
    os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
    os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'
    try:
        program.setup_environment()
        tensorboard = program.TensorBoard(default.get_plugins(),
                                          default.get_assets_zip_provider())
        with _internal.reserve_sock_addr() as (h, p):
            tensorboard_url = f"http://{h}:{p}"
            argv = ['tensorboard', f"--logdir={model_dir}",
                    f"--port={p}"]
            # Append more arguments if needed. split() without a separator
            # drops empty tokens, so an empty value or repeated/trailing
            # spaces never inject '' as a bogus argument.
            argv += os.environ.get('TF_BOARD_EXTRA_ARGS', '').split()
            tensorboard.configure(argv)
        tensorboard.launch()
        event.url_event(client, task, f"Tensorboard is listening at {tensorboard_url}")
        # An IndexError here (no such thread) is reported like any other
        # failure by the handler below.
        thread = [t for t in threading.enumerate() if t.name == 'TensorBoard'][0]
    except Exception as e:
        event.stop_event(client, task, e)

    return thread
Ejemplo n.º 5
0
def _setup_container_logs(client):
    """Broadcast the container start time and publish its log URL.

    The current container is the one whose ``yarn_container_id`` equals the
    ``CONTAINER_ID`` environment variable. Its log location is sent through
    a logs event, prefixed with ``http://`` when it carries no HTTP(S)
    scheme yet.

    Args:
        client: skein application client used to list the containers and
            to emit the events.

    Raises:
        StopIteration: if no container matches ``CONTAINER_ID``.
        KeyError: if ``CONTAINER_ID`` is not set while containers are
            being examined.
    """
    task = cluster.get_task()
    event.broadcast_container_start_time(client, task)
    container = next(c for c in client.get_containers()
                     if c.yarn_container_id == os.environ["CONTAINER_ID"])
    logs = container.yarn_container_logs
    # Accept both schemes: prepending "http://" to an "https://" URL would
    # produce an unusable "http://https://..." link.
    if logs is not None and not logs.startswith(("http://", "https://")):
        logs = "http://" + logs
    event.logs_event(client, task, logs)
Ejemplo n.º 6
0
def _get_experiment(client: skein.ApplicationClient) -> Experiment:
    """Fetch the experiment factory from the key-value store and call it.

    Waits for the ``KV_EXPERIMENT_FN`` entry, deserializes it with ``dill``
    and returns the result of calling the deserialized object without
    arguments.

    Raises:
        Exception: any failure is re-raised after a start event and a stop
            event carrying the exception have been emitted for this task.
    """
    try:
        # NOTE(review): deserializing with dill executes code embedded in
        # the payload; the key-value store content must be trusted.
        serialized_fn = client.kv.wait(KV_EXPERIMENT_FN)
        experiment_fn = dill.loads(serialized_fn)
        return experiment_fn()
    except Exception as error:
        task = cluster.get_task()
        # Emit a start event first so the failure is reported as a
        # start/stop pair.
        event.start_event(client, task)
        event.stop_event(client, task, error)
        raise
Ejemplo n.º 7
0
def _get_experiment(client: skein.ApplicationClient) -> NamedTuple:
    """Fetch the experiment factory from the key-value store and call it.

    Waits for the ``constants.KV_EXPERIMENT_FN`` entry, deserializes it with
    ``cloudpickle`` and returns the result of calling the deserialized
    object without arguments.

    Raises:
        Exception: any failure is re-raised after a start event and a stop
            event carrying the exception have been emitted for this task.
    """
    try:
        # NOTE(review): unpickling executes code embedded in the payload;
        # the key-value store content must be trusted.
        serialized_fn = client.kv.wait(constants.KV_EXPERIMENT_FN)
        experiment_fn = cloudpickle.loads(serialized_fn)
        return experiment_fn()
    except Exception as error:
        task = cluster.get_task()
        # Emit a start event first so the failure is reported as a
        # start/stop pair.
        event.start_event(client, task)
        event.stop_event(client, task, error)
        raise
Ejemplo n.º 8
0
def _gen_monitored_train_and_evaluate(client: skein.ApplicationClient):
    """Build a ``train_and_evaluate`` wrapper that broadcasts timer events.

    The current task is resolved once, when the wrapper is generated. The
    returned callable takes ``(estimator, train_spec, eval_spec)``,
    broadcasts the start timer, runs ``tf.estimator.train_and_evaluate``
    and then broadcasts the stop timer. If training raises, the stop timer
    is not broadcast.
    """
    task = cluster.get_task()

    def _monitored_train_and_evaluate(estimator: tf.estimator,
                                      train_spec: tf.estimator.TrainSpec,
                                      eval_spec: tf.estimator.EvalSpec):
        # Bracket the training run with start/stop timer broadcasts.
        event.broadcast_train_eval_start_timer(client, task)
        tf.estimator.train_and_evaluate(
            estimator,
            train_spec,
            eval_spec,
        )
        event.broadcast_train_eval_stop_timer(client, task)

    return _monitored_train_and_evaluate
Ejemplo n.º 9
0
def _execute_dispatched_function(client: skein.ApplicationClient,
                                 experiment: Experiment) -> MonitoredThread:
    """Run the experiment in a monitored daemon thread and return the thread.

    The thread is named ``<task_type>:<task_id>`` and receives the fields
    of ``experiment`` as positional arguments. A start event is emitted for
    the current task after the thread has been started.
    """
    task_type, task_id = cluster.get_task_description()
    _logger.info(f"Starting execution {task_type}:{task_id}")

    monitored = MonitoredThread(
        name=f"{task_type}:{task_id}",
        target=_gen_monitored_train_and_evaluate(client),
        args=tuple(experiment),
        daemon=True,
    )
    monitored.start()

    event.start_event(client, cluster.get_task())
    return monitored
Ejemplo n.º 10
0
def main():
    """Entry point of a task that only does work for the evaluator.

    Emits the init event, publishes the container logs, runs
    ``evaluator_fn`` when the task type is ``"evaluator"`` (otherwise just
    logs that there is nothing to do) and finally emits a stop event
    without exception.
    """
    client = skein.ApplicationClient.from_current()
    task = cluster.get_task()
    task_type, task_id = cluster.get_task_description()

    event.init_event(client, task, "127.0.0.1:0")
    _task_commons._setup_container_logs(client)

    if task_type != "evaluator":
        logger.info(f"{task_type}:{task_id}: nothing to do")
    else:
        evaluator_fn(client)

    event.stop_event(client, task, None)
Ejemplo n.º 11
0
def main():
    """Entry point dispatching on the task type.

    Emits the init event, publishes the container logs, then runs
    ``_worker_fn`` for ``chief``/``worker`` tasks or ``_evaluator_fn`` for
    the ``evaluator``. An unknown task type is only logged as an error. A
    stop event without exception is emitted in every case.
    """
    client = skein.ApplicationClient.from_current()
    task_type, task_id = cluster.get_task_description()
    task = cluster.get_task()

    event.init_event(client, task, "127.0.0.1:0")
    _task_commons._setup_container_logs(client)

    if task_type == 'evaluator':
        _evaluator_fn(client)
    elif task_type in ('chief', 'worker'):
        _worker_fn(task_type, task_id, client)
    else:
        logger.error(f'Unknown task type {task_type}')

    event.stop_event(client, task, None)
Ejemplo n.º 12
0
def _shutdown_container(client: skein.ApplicationClient,
                        cluster_tasks: List[str],
                        session_config: tf.compat.v1.ConfigProto,
                        thread: Optional[MonitoredThread]) -> None:
    """Publish this task's outcome, wait for its peers and close the container.

    Args:
        client: skein application client used to emit the events.
        cluster_tasks: names of the tasks in the cluster.
        session_config: session configuration inspected for
            ``device_filters``.
        thread: thread that ran the task, if any.

    Raises:
        The exception recorded on ``thread`` (when it is a
        ``MonitoredThread`` holding one), re-raised with its context
        suppressed once all shutdown events have been sent.
    """
    # isinstance() is False for None, so a missing thread yields no exception.
    exception = None
    if isinstance(thread, MonitoredThread):
        exception = thread.exception

    task = cluster.get_task()
    event.stop_event(client, task, exception)

    # Wait for all tasks connected to this one. The set of tasks to
    # wait for contains all tasks in the cluster, or the ones
    # matching ``device_filters`` if set. The implementation assumes
    # that ``device_filters`` are symmetric.
    device_filters = getattr(session_config, "device_filters", [])
    wait_for_connected_tasks(client, cluster_tasks, device_filters)

    event.broadcast_container_stop_time(client, task)

    if exception is None:
        return
    raise exception from None
Ejemplo n.º 13
0
def _execute_dispatched_function(
        client: skein.ApplicationClient,
        experiment: Union[Experiment, KerasExperiment]) -> MonitoredThread:
    """Run an ``Experiment`` in a monitored daemon thread and return it.

    The thread is named ``<task_type>:<task_id>`` and receives the fields
    of ``experiment`` as positional arguments. A start event is emitted for
    the current task after the thread has been started.

    Raises:
        ValueError: if ``experiment`` is a ``KerasExperiment`` (unsupported
            here) or is neither an ``Experiment`` nor a
            ``KerasExperiment``.
    """
    task_type, task_id = cluster.get_task_description()
    _logger.info(f"Starting execution {task_type}:{task_id}")

    # Reject everything that is not a plain Experiment up front.
    if not isinstance(experiment, Experiment):
        if isinstance(experiment, KerasExperiment):
            raise ValueError(
                "KerasExperiment using parameter strategy is unsupported")
        raise ValueError(
            "experiment must be an Experiment or a KerasExperiment")

    monitored = MonitoredThread(
        name=f"{task_type}:{task_id}",
        target=_gen_monitored_train_and_evaluate(client),
        args=tuple(experiment),
        daemon=True,
    )
    monitored.start()

    event.start_event(client, cluster.get_task())
    return monitored
Ejemplo n.º 14
0
def start_tf_board(client: skein.ApplicationClient, tf_board_model_dir: str):
    """Launch tensorboard on ``tf_board_model_dir`` and publish its URL.

    Extra command-line arguments can be supplied, whitespace separated,
    through the ``TB_EXTRA_ARGS`` environment variable.

    Args:
        client: skein application client used to emit the events.
        tf_board_model_dir: directory passed to tensorboard as ``--logdir``.

    Failures are not raised: they are logged with their traceback and
    reported through a stop event carrying the exception.
    """
    task = cluster.get_task()
    os.environ['GCS_READ_CACHE_DISABLED'] = '1'
    os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
    os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'
    try:
        program.setup_environment()
        tensorboard = program.TensorBoard()
        with _internal.reserve_sock_addr() as (h, p):
            tensorboard_url = f"http://{h}:{p}"
            argv = ['tensorboard', f"--logdir={tf_board_model_dir}",
                    f"--port={p}"]
            # split() without a separator drops empty tokens, so repeated
            # or trailing spaces never inject '' as a bogus argument.
            argv += os.getenv('TB_EXTRA_ARGS', "").split()
            tensorboard.configure(argv)
        tensorboard.launch()
        event.start_event(client, task)
        event.url_event(client, task, f"{tensorboard_url}")
    except Exception as e:
        # Passing ``e`` as a positional argument without a placeholder made
        # logging fail to format the record; exception() logs the message
        # together with the active traceback instead.
        _logger.exception("Cannot start tensorboard")
        event.stop_event(client, task, e)