def tensorflow_scheduler(global_future, model_name, client=None, tf_option=None,
                         tf_port=None, **tf_cluster_spec):
    """Start a TensorFlow cluster as actors on the dask workers and wait for it.

    This is a generator-based coroutine (it ``yield``s futures), so it must be
    driven by a coroutine runner on ``client.loop``.

    Steps:
      1. Ask the scheduler for its identity and run ``cuda_free_indexes`` on
         every worker.
      2. Build one config per TensorFlow task with ``tensorflow_gen_config``.
      3. Schedule ``startup_actors`` on the client loop and wait for it to
         resolve the actor mapping.
      4. Collect messages published on the ``model_name`` topic.
      5. Check that the chief task reported, run ``model_cleanup``, and resolve
         ``global_future``.

    Args:
        global_future: Future resolved with the actor mapping on success.
        model_name: Topic name for ``Sub``/``Pub``. Also passed to
            ``startup_actors`` and ``model_cleanup`` and used in log messages.
        client: Dask client. Its attributes are used unconditionally, so it
            must not be ``None`` despite the default.
        tf_option: ``None``/falsy, ``str``/``bytes``, or an object with
            ``SerializeToString()``. Presumably a protobuf such as a TF
            ``ConfigProto``; confirm.
        tf_port: Accepted for interface compatibility. Not used here.
        **tf_cluster_spec: Forwarded to ``tensorflow_gen_config``.

    Raises:
        AssertionError: If the chief task key is not among the received
            messages.
    """
    scheduler_info = yield client.scheduler.identity()
    cuda_free_map = yield client.run(cuda_free_indexes)
    tf_configs = tensorflow_gen_config(free_node_name_map=cuda_free_map,
                                       **tf_cluster_spec)
    logger.info('Model Schedule %s: \n tf_configs:%s\n\n', model_name, tf_configs)

    # str/bytes pass through unchanged, other truthy objects are serialized,
    # and falsy values (e.g. None) pass through as-is.
    if not isinstance(tf_option, (str, bytes)) and tf_option:
        tf_option = tf_option.SerializeToString()

    # Size the loop's default executor to one thread per task config.
    client.loop.set_default_executor(
        ThreadPoolExecutor(max_workers=len(tf_configs)))

    # startup_actors runs on the client loop and resolves chief_future.
    # Presumably the result is a mapping of task key -> actor; confirm.
    chief_future = Future()
    client.loop.add_callback(startup_actors, scheduler_info, client, model_name,
                             tf_option, tf_configs, chief_future)
    chief_actors = yield chief_future
    sorted_task_keys = sorted(chief_actors, key=dask_sork_key)

    sub = Sub(model_name, client=client)
    # Not read below. The publishers are kept referenced for the lifetime of
    # this coroutine (presumably so they stay registered; confirm).
    pubs = {k: Pub(model_name, client=client) for k in sorted_task_keys}
    # Round-trip to the scheduler so this client and the scheduler are synced
    # before messages are awaited. The result itself is not needed.
    yield client.scheduler.identity()

    # Chief first: the first key in dask sort order is treated as the chief.
    chief_key = sorted_task_keys[0]
    msgs = {}
    # Stop once one fewer messages than there are actors have been merged.
    while len(msgs) + 1 < len(chief_actors):
        msg = yield sub._get()
        logger.info('Sub Rcv %s:%s', type(msg), msg)
        msgs.update(msg)

    # Explicit raise instead of `assert`, so the check survives `python -O`.
    # The exception type callers see is unchanged.
    if chief_key not in msgs:
        raise AssertionError('Tensorflow Chief Task Required: %s' % chief_key)

    # Pause for 1 s before cleanup without blocking. The previous
    # time.sleep(1) stalled the event loop that drives this coroutine.
    pause = Future()
    client.loop.call_later(1, pause.set_result, None)
    yield pause

    yield model_cleanup(client, model_name)
    logger.info("Tensorflow Task clean, %s", chief_actors)
    global_future.set_result(chief_actors)