def main() -> None:
    """Entry point for a TensorFlow parameter-server ("ps") style task.

    Reserves a port, prepares the skein container, publishes TF_CONFIG,
    starts a TF server, dispatches the experiment's function on a monitored
    thread, and finally shuts the container down.
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()
    # Hold a socket on a free port while the cluster spec is assembled so the
    # address stays reserved; the TF server binds it after the `with` exits.
    with _internal.reserve_sock_addr() as host_port:
        client, cluster_spec, cluster_tasks = _task_commons._prepare_container(
            host_port)
        # Variable TF_CONFIG must be set before instantiating
        # the estimator to train in a distributed way
        cluster.setup_tf_config(cluster_spec)
        experiment = _task_commons._get_experiment(client)
        if isinstance(experiment, Experiment):
            session_config = experiment.config.session_config
        elif isinstance(experiment, KerasExperiment):
            raise ValueError(
                "KerasExperiment using parameter strategy is unsupported")
        else:
            raise ValueError(
                "experiment must be an Experiment or a KerasExperiment")
        _logger.info(f"Starting server {task_type}:{task_id}")
    cluster.start_tf_server(cluster_spec, session_config)
    thread = _task_commons._execute_dispatched_function(client, experiment)
    # "ps" tasks do not terminate by themselves. See
    # https://github.com/tensorflow/tensorflow/issues/4713.
    if task_type not in ['ps']:
        thread.join()
        _logger.info(f"{task_type}:{task_id} {thread.state}")
    _task_commons._shutdown_container(client, cluster_tasks,
                                      session_config, thread)
def main() -> None:
    """Entry point for the dedicated TensorBoard task.

    Launches TensorBoard on the experiment's model directory in a daemon
    thread, then waits for every other cluster task to publish its stop
    event before terminating (with a bounded grace period).
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()
    task = cluster.get_task()
    client = skein.ApplicationClient.from_current()
    _task_commons._setup_container_logs(client)
    cluster_tasks = _task_commons._get_cluster_tasks(client)
    # TB_MODEL_DIR overrides the estimator's model_dir when provided.
    model_dir = os.getenv('TB_MODEL_DIR', "")
    if not model_dir:
        _logger.info("Read model_dir from estimator config")
        experiment = _task_commons._get_experiment(client)
        model_dir = experiment.estimator.config.model_dir
    _logger.info(f"Starting tensorboard on {model_dir}")
    # daemon=True so a hung TensorBoard cannot keep the container alive.
    thread = _internal.MonitoredThread(
        name=f"{task_type}:{task_id}",
        target=tensorboard.start_tf_board,
        args=(client, model_dir),
        daemon=True)
    thread.start()
    # Block until every other task signals "<task>/stop".
    for cluster_task in cluster_tasks:
        event.wait(client, f"{cluster_task}/stop")
    # Give TensorBoard a bounded grace period before reporting our own stop.
    timeout = tensorboard.get_termination_timeout()
    thread.join(timeout)
    event.stop_event(client, task, thread.exception)
    event.broadcast_container_stop_time(client, task)
def main() -> None:
    """Entry point for a "standard TensorFlow server" task.

    Publishes TF_CONFIG for this container, starts the stock TF server via
    tf.contrib.distribute, and blocks until a global "stop" event arrives.
    """
    _task_commons._log_sys_info()
    task_type, task_id = cluster.get_task_description()
    # Reserve a free port while the cluster spec is built; the server binds
    # it (via TF_CONFIG) after the reservation socket is released.
    with _internal.reserve_sock_addr() as host_port:
        client, cluster_spec, cluster_tasks = _task_commons._prepare_container(
            host_port)
        # TF_CONFIG must be in the environment before the server starts.
        cluster.setup_tf_config(cluster_spec)
    # NOTE(review): deserializing pickled data from the skein key-value store
    # is only safe because it comes from our own application master — do not
    # point this at untrusted input.
    tf_session_config = cloudpickle.loads(
        client.kv.wait(constants.KV_TF_SESSION_CONFIG))
    _logger.info(f"tf_server_conf {tf_session_config}")
    tf.contrib.distribute.run_standard_tensorflow_server()
    # The server runs until the application broadcasts "stop".
    event.wait(client, "stop")
def _execute_dispatched_function(client: skein.ApplicationClient,
                                 experiment: Experiment) -> MonitoredThread:
    """Launch the experiment's train-and-evaluate loop on a monitored thread.

    The thread is started as a daemon and returned immediately; a start
    event is published for this task so the rest of the cluster can see it.
    """
    task_type, task_id = cluster.get_task_description()
    _logger.info(f"Starting execution {task_type}:{task_id}")
    worker = MonitoredThread(
        name=f"{task_type}:{task_id}",
        target=_gen_monitored_train_and_evaluate(client),
        # Experiment is iterable; its items become the target's arguments.
        args=tuple(experiment),
        daemon=True,
    )
    worker.start()
    event.start_event(client, cluster.get_task())
    return worker
def main():
    """Container entry point: run the evaluator, or no-op for other tasks."""
    client = skein.ApplicationClient.from_current()
    task = cluster.get_task()
    task_type, task_id = cluster.get_task_description()
    # Announce this container with a placeholder address before doing work.
    event.init_event(client, task, "127.0.0.1:0")
    _task_commons._setup_container_logs(client)
    if task_type != "evaluator":
        # Only the evaluator has real work in this launcher.
        logger.info(f"{task_type}:{task_id}: nothing to do")
    else:
        evaluator_fn(client)
    event.stop_event(client, task, None)
def main():
    """Container entry point: dispatch to the worker or evaluator routine.

    Publishes init/stop events around the dispatched work so the rest of
    the cluster can track this container's lifecycle.
    """
    client = skein.ApplicationClient.from_current()
    task_type, task_id = cluster.get_task_description()
    task = cluster.get_task()
    # Fix: this was an f-string with no placeholders (ruff F541); a plain
    # literal produces the identical runtime value.
    event.init_event(client, task, "127.0.0.1:0")
    _task_commons._setup_container_logs(client)
    if task_type in ['chief', 'worker']:
        _worker_fn(task_type, task_id, client)
    elif task_type == 'evaluator':
        _evaluator_fn(client)
    else:
        # Unknown task types are logged but still stop cleanly below.
        logger.error(f'Unknown task type {task_type}')
    event.stop_event(client, task, None)
def _execute_dispatched_function(
        client: skein.ApplicationClient,
        experiment: Union[Experiment, KerasExperiment]) -> MonitoredThread:
    """Launch the experiment's dispatched function on a monitored daemon thread.

    Only plain ``Experiment`` instances are runnable here; the branch order
    (Experiment checked before KerasExperiment) is deliberately preserved.
    Raises ``ValueError`` for KerasExperiment or unknown experiment objects.
    """
    task_type, task_id = cluster.get_task_description()
    _logger.info(f"Starting execution {task_type}:{task_id}")
    if isinstance(experiment, Experiment):
        worker = MonitoredThread(
            name=f"{task_type}:{task_id}",
            target=_gen_monitored_train_and_evaluate(client),
            # Experiment is iterable; its items become the target's arguments.
            args=tuple(experiment),
            daemon=True,
        )
    elif isinstance(experiment, KerasExperiment):
        raise ValueError(
            "KerasExperiment using parameter strategy is unsupported")
    else:
        raise ValueError(
            "experiment must be an Experiment or a KerasExperiment")
    worker.start()
    event.start_event(client, cluster.get_task())
    return worker
def test_get_task_description():
    """get_task_description() splits SKEIN_CONTAINER_ID into (type, index).

    Fix: the original `assert "MYTASK", 42 == ...` parses as
    `assert "MYTASK"` with `42 == ...` as the assertion *message*, so it
    always passed. Compare against the full expected tuple instead.
    """
    with mock.patch.dict(os.environ):
        os.environ["SKEIN_CONTAINER_ID"] = "MYTASK_42"
        assert cluster.get_task_description() == ("MYTASK", 42)