def _shutdown_container(
    client: skein.ApplicationClient,
    cluster_tasks: List[str],
    run_config: tf.estimator.RunConfig,
    thread: Optional[MonitoredThread]
) -> None:
    """Run the end-of-task sequence for the current container.

    The steps, in order:

    1. Read the exception recorded on ``thread`` (only when it is a
       ``MonitoredThread``) and pass it to ``event.stop_event``.
    2. Wait for all tasks connected to this one. The set of tasks to wait
       for contains all non-tensorboard tasks in the cluster, or the ones
       matching ``device_filters`` if set. The implementation assumes that
       ``device_filters`` are symmetric.
    3. If the current task is a tensorboard task, join ``thread`` with the
       termination timeout and log that the task finished.
    4. Call ``event.broadcast_container_stop_time``.

    Args:
        client: Skein application client forwarded to the event helpers.
        cluster_tasks: Names of the tasks in the cluster, or ``None``; a
            ``None`` value is forwarded unchanged to
            ``wait_for_connected_tasks``.
        run_config: Estimator run config; only
            ``session_config.device_filters`` is read, defaulting to ``[]``
            when the attribute is missing.
        thread: Thread running the task's workload, if any.

    Raises:
        BaseException: Whatever exception was recorded on ``thread``,
            re-raised with its context suppressed once every step above
            has completed.
    """
    # NOTE(review): the exception is sampled before the tensorboard join
    # below, so a failure recorded during the join is not reported here.
    failure = None
    if thread is not None and isinstance(thread, MonitoredThread):
        failure = thread.exception

    task = cluster.get_task()
    event.stop_event(client, task, failure)

    # Tensorboard tasks are excluded from the set of tasks to wait for.
    peers = None
    if cluster_tasks is not None:
        peers = [
            name for name in cluster_tasks
            if not name.startswith('tensorboard')
        ]
    device_filters = getattr(run_config.session_config, "device_filters", [])
    wait_for_connected_tasks(client, peers, device_filters)

    if task.startswith('tensorboard'):
        grace_period = get_termination_timeout()
        if thread is not None:
            thread.join(grace_period)
        tf.logging.info(f"{task} finished")

    event.broadcast_container_stop_time(client, task)

    if failure is None:
        return
    raise failure from None
def main() -> None:
    """Run TensorBoard in a monitored thread for the lifetime of the cluster.

    The model directory is taken from the ``TB_MODEL_DIR`` environment
    variable; when that is unset or empty it is read from the experiment's
    estimator config instead. ``tensorboard.start_tf_board`` is then started
    in a daemon ``MonitoredThread``, and ``event.wait`` is called on the
    ``<task>/stop`` key of every cluster task. Finally the thread is joined
    with the tensorboard termination timeout, and the stop event (carrying
    the thread's recorded exception) and the container stop time are
    published.
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()
    task = cluster.get_task()
    client = skein.ApplicationClient.from_current()

    _task_commons._setup_container_logs(client)
    cluster_tasks = _task_commons._get_cluster_tasks(client)

    # Prefer the explicit environment override; fall back to the estimator.
    model_dir = os.getenv('TB_MODEL_DIR', "")
    if not model_dir:
        _logger.info("Read model_dir from estimator config")
        estimator_config = _task_commons._get_experiment(client).estimator.config
        model_dir = estimator_config.model_dir

    _logger.info(f"Starting tensorboard on {model_dir}")

    board_thread_name = f"{task_type}:{task_id}"
    board_thread = _internal.MonitoredThread(
        name=board_thread_name,
        target=tensorboard.start_tf_board,
        args=(client, model_dir),
        daemon=True,
    )
    board_thread.start()

    # One wait per cluster task, in the order the tasks are listed.
    for cluster_task in cluster_tasks:
        event.wait(client, f"{cluster_task}/stop")

    board_thread.join(tensorboard.get_termination_timeout())

    event.stop_event(client, task, board_thread.exception)
    event.broadcast_container_stop_time(client, task)
def _shutdown_container(client: skein.ApplicationClient,
                        cluster_tasks: List[str],
                        session_config: tf.compat.v1.ConfigProto,
                        thread: Optional[MonitoredThread]) -> None:
    """Run the end-of-task sequence for the current container.

    Passes the exception recorded on ``thread`` (if it is a
    ``MonitoredThread``) to ``event.stop_event``, then waits for all tasks
    connected to this one, and finally calls
    ``event.broadcast_container_stop_time``.

    The set of tasks to wait for contains all tasks in the cluster, or the
    ones matching ``device_filters`` if set. The implementation assumes
    that ``device_filters`` are symmetric.

    Args:
        client: Skein application client forwarded to the event helpers.
        cluster_tasks: Names of the tasks in the cluster, forwarded
            unchanged to ``_wait_for_connected_tasks``.
        session_config: Session configuration; only its ``device_filters``
            attribute is read, defaulting to ``[]`` when it is missing
            (including when ``session_config`` is ``None``).
        thread: Thread running the task's workload, if any.

    Raises:
        BaseException: Whatever exception was recorded on ``thread``,
            re-raised with its context suppressed once the stop time has
            been broadcast.
    """
    failure = None
    if thread is not None and isinstance(thread, MonitoredThread):
        failure = thread.exception

    current_task = get_task()
    event.stop_event(client, current_task, failure)

    device_filters = getattr(session_config, "device_filters", [])
    _wait_for_connected_tasks(client, cluster_tasks, device_filters)

    event.broadcast_container_stop_time(client, current_task)

    if failure is None:
        return
    raise failure from None