Example #1
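# Context assumed for this excerpt (not shown in the original): the yield-based
# style suggests a tornado ``gen.coroutine`` driven through a dask.distributed
# Client, with Pub/Sub from distributed for task messaging. Project-local
# helpers (logger, cuda_free_indexes, tensorflow_gen_config, startup_actors,
# model_cleanup, dask_sork_key) live elsewhere in the source module.
import datetime
import time
from concurrent.futures import ThreadPoolExecutor

from distributed import Pub, Sub
from tornado import gen
from tornado.concurrent import Future


@gen.coroutine  # assumed decorator; the original excerpt starts at the def itself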
def tensorflow_scheduler(global_future,
                         model_name,
                         client=None,
                         tf_option=None,
                         tf_port=None,
                         **tf_cluster_spec):
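    """Schedule a TensorFlow cluster across Dask workers.

    Builds one TF_CONFIG per task from the workers' free-GPU map and
    ``tf_cluster_spec``, starts the task actors, waits for completion messages
    on the model's pub/sub channel, cleans the model up, and finally resolves
    ``global_future`` with the chief actors.
    """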
    scheduler_info = yield client.scheduler.identity()
    # Ask every worker which CUDA devices are free, then build one TF_CONFIG
    # per task from that map and the supplied cluster spec.
    cuda_free_map = yield client.run(cuda_free_indexes)
    tf_configs = tensorflow_gen_config(free_node_name_map=cuda_free_map,
                                       **tf_cluster_spec)

    logger.info('Model Schedule %s: \n  tf_configs:%s\n\n', model_name,
                tf_configs)

    # Normalize tf_option: serialize protobuf options to bytes, pass str/bytes
    # (or None) through unchanged.
    if tf_option and not isinstance(tf_option, (str, bytes)):
        tf_option = tf_option.SerializeToString()

    chief_configs, ps_configs, other_configs = [], [], []
    for tf_config in tf_configs:
        task_type = tf_config['task']['type']
        task_index = tf_config['task']['index']
        if task_type in ('chief', 'master'):
            chief_configs.append(tf_config)
        elif task_type in ('ps', ):
            ps_configs.append(tf_config)
        else:
            other_configs.append(tf_config)

    s_time = time.time()
    dt = datetime.datetime.now()

    chief_configs.sort(key=lambda cfg: cfg['task']['index'])
    ps_configs.sort(key=lambda cfg: cfg['task']['index'])
    other_configs.sort(key=lambda cfg: cfg['task']['index'])

    # Give the client's IOLoop enough default-executor threads to run one
    # blocking task per TensorFlow config.
    client.loop.set_default_executor(
        ThreadPoolExecutor(max_workers=len(tf_configs)))

    # Carry the generated configs, serialized options and GPU map on a Future
    # (attached as attributes; this future is not resolved in this excerpt).
    result_future = Future()
    result_future.tf_configs = tf_configs
    result_future.tf_option = tf_option
    result_future.cuda_map = cuda_free_map

    # Launch the TensorFlow task actors on the client's IOLoop; startup_actors
    # resolves chief_future with the actor handles once they are up.
    chief_future = Future()
    client.loop.add_callback(startup_actors, scheduler_info, client,
                             model_name, tf_option, tf_configs, chief_future)
    chief_actors = yield chief_future

    sorted_task_keys = sorted(chief_actors, key=dask_sork_key)

    # Per-model pub/sub channel: tasks publish completion messages that are
    # collected below (pubs is built per task key but not used further here).
    sub = Sub(model_name, client=client)
    pubs = {k: Pub(model_name, client=client) for k in sorted_task_keys}
    # Round-trip to the scheduler to flush/sync state between this client and
    # the scheduler before listening for task results.
    scheduler_info = yield client.scheduler.identity()

    # Completion callback: record each finished task's actor and log overall
    # progress (defined here but not attached to any future in this excerpt).
    def chief_finish(task_key, actor, fu):
        value = fu.result()
        logger.info('Tensorflow Finished[%s/%s], key:%s, val:%s',
                    len(chief_actors), len(tf_configs), task_key, value)
        chief_actors[task_key] = actor
        if len(chief_actors) == len(tf_configs):
            logger.info('Tensorflow Cluster All Finished: %s',
                        chief_actors.keys())

    # Chief first: every non-chief task publishes a completion message, so
    # collect messages until all len(chief_actors) - 1 of them have reported.
    msgs = {}
    chief_key_actor = sorted_task_keys[0]

    while (len(msgs) + 1) < len(chief_actors):
        msg = yield sub._get()
        logger.info('Sub Rcv %s:%s', type(msg), msg)
        msgs.update(msg)

    assert chief_key_actor in msgs, 'Tensorflow Chief Task Required: %s' % chief_key_actor
    time.sleep(1)  # note: blocks the IOLoop briefly before cleanup
    yield model_cleanup(client, model_name)
    logger.info('Tensorflow Task clean, %s', chief_actors)
    global_future.set_result(chief_actors)
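

# Usage sketch (not part of the original example): drive the coroutine from a
# synchronous client. The scheduler address, model name, and the cluster-spec
# keywords (ps=1, worker=2) are placeholders for whatever tensorflow_gen_config
# actually expects.
#
# from distributed import Client
# from tornado import gen
# from tornado.concurrent import Future
#
# @gen.coroutine
# def launch(client):
#     done = Future()
#     yield tensorflow_scheduler(done, 'demo_model', client=client,
#                                ps=1, worker=2)
#     chief_actors = yield done
#     raise gen.Return(chief_actors)
#
# client = Client('tcp://127.0.0.1:8786')
# actors = client.sync(launch, client)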