Exemplo n.º 1
0
    async def publish():
        """Forever publish 0, 1, 2, ... on topic "a", one value per ~10 ms.

        Uses the private ``Pub._put`` method, as the surrounding test does.
        """
        publisher = Pub("a")

        counter = 0
        while True:
            await asyncio.sleep(0.01)
            value, counter = counter, counter + 1
            publisher._put(value)
Exemplo n.º 2
0
    async def publish():
        """Forever publish consecutive integers from 0 on topic 'a'.

        A value is sent after each ``gen.sleep(0.01)``, via ``Pub._put``.
        """
        topic_pub = Pub('a')

        sent = 0
        while True:
            await gen.sleep(0.01)
            topic_pub._put(sent)
            sent = sent + 1
Exemplo n.º 3
0
    async def publish():
        """Forever publish 0, 1, 2, ... on topic 'a' with the public ``put``.

        One value is emitted after every ``gen.sleep(0.01)``.
        """
        publisher = Pub('a')

        next_value = 0
        while True:
            await gen.sleep(0.01)
            current, next_value = next_value, next_value + 1
            publisher.put(current)
Exemplo n.º 4
0
async def test_repr(c, s, a, b):
    """The str() of Pub and Sub mentions both the class name and the topic."""
    topic = "my-topic"
    pub = Pub(topic)
    sub = Sub(topic)
    for obj, class_name in ((pub, "Pub"), (sub, "Sub")):
        text = str(obj)
        assert topic in text
        assert class_name in text
Exemplo n.º 5
0
    def pingpong(a, b, start=False, n=1000, msg=1):
        """Relay messages from topic ``a`` to topic ``b`` for ``n`` rounds.

        Blocks until ``b`` has a subscriber. If ``start`` is true, it first
        sends ``msg`` to kick off the exchange. Every subsequent message
        received on ``a`` is forwarded to ``b``. Returns ``n``.
        """
        incoming = Sub(a)
        outgoing = Pub(b)

        # Wait for the peer to subscribe to ``b``.
        while not outgoing.subscribers:
            sleep(0.01)

        if start:
            # The peer's Sub may not have started yet, hence the wait above.
            outgoing.put(msg)

        for _ in range(n):
            outgoing.put(next(incoming))
        return n
Exemplo n.º 6
0
def test_client(c, s):
    """Client-side Pub/Sub round trip on topic 'a'.

    Outside a worker, ``get_worker`` must fail. The test then waits (at most
    ~3 s) until the scheduler's pubsub extension lists exactly this client
    as a subscriber of 'a', publishes 123, and expects to receive it back.
    """
    with pytest.raises(Exception):
        get_worker()
    sub = Sub('a')
    pub = Pub('a')

    scheduler_pubsub = s.extensions['pubsub']
    client_pubsub = c.extensions['pubsub']  # unused; the lookup fails if missing

    deadline = time() + 3
    while set(scheduler_pubsub.client_subscribers['a']) != {c.id}:
        yield gen.sleep(0.01)
        assert time() < deadline

    pub.put(123)

    received = yield sub.__anext__()
    assert received == 123
Exemplo n.º 7
0
def test_client(c, s):
    """Publish from a client and receive on a client-side Sub (topic "a").

    Checks that ``get_worker`` fails outside a worker. It then polls, failing
    after ~3 s, until the scheduler records this client as the sole
    subscriber of "a", and finally verifies that 123 is delivered.
    """
    with pytest.raises(Exception):
        get_worker()
    subscriber = Sub("a")
    publisher = Pub("a")

    server_ext = s.extensions["pubsub"]
    client_ext = c.extensions["pubsub"]  # not used below; KeyError if absent

    t0 = time()
    expected = {c.id}
    while set(server_ext.client_subscribers["a"]) != expected:
        yield gen.sleep(0.01)
        assert time() < t0 + 3

    publisher.put(123)

    assert (yield subscriber.__anext__()) == 123
Exemplo n.º 8
0
    def _sampler_worker(i, sampler_class):
        """Run sampling chain ``i`` with a fresh ``sampler_class`` instance.

        Native thread pools are limited to a single thread while sampling.
        The resulting trace is annotated with samples and log densities
        mapped back to the original space. On any exception, when ``use_dask``
        is set, ``['Error', i]`` is published on ``dask_key`` before the
        exception propagates.
        """
        try:
            with threadpool_limits(1):
                chain_trace = nested_helper(sample_trace, i)

                def logp_and_grad(x):
                    return density.logp_and_grad(x, original_space=False)

                sampler = sampler_class(
                    logp_and_grad=logp_and_grad,
                    sample_trace=chain_trace,
                    dask_key=dask_key,
                    process_lock=process_lock,
                )
                result = sampler.run(n_run, verbose)
                result._samples_original = density.to_original(result.samples)
                result._logp_original = density.to_original_density(
                    result.logp, x_trans=result.samples)
            return result
        except Exception:
            if use_dask:
                Pub(dask_key).put(['Error', i])
            raise
Exemplo n.º 9
0
    def pingpong(a, b, start=False, n=1000, msg=1):
        """Echo ``n`` messages received on topic ``a`` onto topic ``b``.

        Waits until someone subscribes to ``b``. When ``start`` is true, it
        first publishes ``msg`` to start the exchange. Returns ``n``.
        """
        source = Sub(a)
        sink = Pub(b)

        while not sink.subscribers:  # block until the peer is listening
            sleep(0.01)

        if start:
            sink.put(msg)  # the other side's Sub may not have started yet

        rounds = 0
        while rounds < n:
            received = next(source)
            sink.put(received)
            rounds += 1
        return n
Exemplo n.º 10
0
 def f(x):
     """Publish ``x`` on the 'a' topic using a freshly created ``Pub``."""
     Pub('a').put(x)
Exemplo n.º 11
0
    def __init__(self,
                 key,
                 *args,
                 tf_config=None,
                 tf_option=None,
                 scheduler_info=None,
                 **kwargs):
        """Prepare the environment for TensorFlow task ``key`` and schedule it.

        Args:
            key: Task key ``'<model>[.zip]:<task>'``. The part before ``':'``
                names the model (a trailing ``.zip`` is stripped). The part
                after ``':'`` names the log file ``<task>.log``.
            tf_config: Cluster/task config. A ``str`` is JSON-decoded.
            tf_option: Session options. A ``str`` is JSON-decoded; other
                values (e.g. serialized bytes) are stored unchanged.
            scheduler_info: Scheduler identity, stored on the instance.
            *args, **kwargs: Stored on the instance unchanged.

        Side effects: secedes from the worker thread pool, resets the
        checkpoint dir, changes the working directory, redirects
        ``sys.stdout``/``sys.stderr`` (and their ``__dunder__`` originals) to
        the task log file, and schedules ``self.flight`` on the worker loop.
        """
        # Make this thread the OWNER of the task by seceding from its
        # ThreadPoolExecutor.
        # NOTE: `thread_pool_secede` ONLY works in a NON-coroutine actor
        # executor, ref: worker.actor_execute()
        self.dask_worker = get_worker()
        self.thrid = threading.get_ident()

        thread_pool_secede(adjust=True)
        self.dask_worker.loop.add_callback(self.dask_worker.transition,
                                           thread_state.key, "long-running")

        self.key = key
        self.name = self.dask_worker.name
        self.hostname = socket.gethostname()
        self.address = self.dask_worker.address
        self.scheduler_info = scheduler_info

        model_name = self.key.partition(':')[0]
        self.model_name = model_name[:-4] if model_name.endswith(
            '.zip') else model_name
        self.tf_option = json.loads(tf_option) if isinstance(
            tf_option, str) else tf_option
        self.tf_config = json.loads(tf_config) if isinstance(
            tf_config, str) else tf_config

        self.dask_cwd = os.path.abspath(os.getcwd())
        self.tf_model_pool_dir = os.path.abspath(DASK_MODEL_POOL_DIR)
        self.tf_data_pool_dir = os.path.abspath(DASK_DATA_POOL_DIR)
        self.tf_data_dir = os.path.join(self.tf_data_pool_dir, self.model_name)
        self.tf_config_dir = os.path.join(self.tf_data_dir, 'config')
        self.tf_save_dir = os.path.join(self.tf_data_dir, 'ckpt')
        self.tf_log_dir = os.path.join(self.tf_data_dir, 'log')

        # Start from an empty checkpoint dir and make sure every working dir
        # exists. Done in-process rather than through a shell command, since
        # the paths derive from the task key and `%r` is not shell quoting.
        import shutil
        shutil.rmtree(self.tf_save_dir, ignore_errors=True)
        for path in (self.tf_data_dir, self.tf_config_dir, self.tf_save_dir,
                     self.tf_log_dir):
            os.makedirs(path, exist_ok=True)
        os.chdir(self.tf_data_dir)

        # Save the current stdio so that recovery can restore it later.
        self.sys_stdio = (sys.__stdin__, sys.__stdout__, sys.__stderr__,
                          sys.stdin, sys.stdout, sys.stderr)
        self.stdout = open(os.path.join(self.tf_log_dir, '%s.log' %
                                        self.key.partition(':')[-1]),
                           'a+',
                           encoding=sys.stdout.encoding)
        sys.__stdout__ = sys.__stderr__ = sys.stdout = sys.stderr = self.stdout
        self.stdin = sys.stdin
        self.sys_path, self.sys_argv = sys.path[:], sys.argv[:]

        logger.info(
            'Accepted Tensorflow Key:%s, Job:%s, Options:%s, Scheduler:%s',
            key, tf_config, tf_option, scheduler_info)
        self.devices = dict(tensorflow_devices())
        self.future_chunk_size = DASK_READ_CHUNK_SIZE
        self.args = args
        self.kwargs = kwargs

        # Own topic (task key) for replies; model topic for results.
        self.sub = Sub(self.key, worker=self.dask_worker)
        self.result = Pub(model_name, worker=self.dask_worker)
        self.exe = self.preflight()
        self.dask_worker.loop.add_callback(self.flight, self.exe)
Exemplo n.º 12
0
class TFActor(object):
    """Dask actor that runs one TensorFlow task as a child Python process.

    ``__init__`` secedes from the worker's thread pool and prepares the
    model's data/config/checkpoint/log directories. It then redirects this
    process's stdio to a per-task log file and schedules :meth:`flight` on
    the worker's event loop. When the child exits, :meth:`landed` publishes
    its return code and :meth:`recovery` restores the saved global state.
    """

    def __del__(self):
        # Restore global state if the actor is collected before `landed`
        # ran; `recovery` is a no-op when there is nothing to restore.
        self.recovery()

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.key)

    def __init__(self,
                 key,
                 *args,
                 tf_config=None,
                 tf_option=None,
                 scheduler_info=None,
                 **kwargs):
        """Prepare the environment for TensorFlow task ``key`` and schedule it.

        Args:
            key: Task key ``'<model>[.zip]:<task>'``. The part before ``':'``
                names the model (a trailing ``.zip`` is stripped). The part
                after ``':'`` names the log file ``<task>.log``.
            tf_config: Cluster/task config. A ``str`` is JSON-decoded.
            tf_option: Session options. A ``str`` is JSON-decoded; other
                values (e.g. serialized bytes) are stored unchanged.
            scheduler_info: Scheduler identity, forwarded in the task context.
            *args, **kwargs: Stored on the instance unchanged.
        """
        # Make this thread the OWNER of the task by seceding from its
        # ThreadPoolExecutor.
        # NOTE: `thread_pool_secede` ONLY works in a NON-coroutine actor
        # executor, ref: worker.actor_execute()
        self.dask_worker = get_worker()
        self.thrid = threading.get_ident()

        thread_pool_secede(adjust=True)
        self.dask_worker.loop.add_callback(self.dask_worker.transition,
                                           thread_state.key, "long-running")

        self.key = key
        self.name = self.dask_worker.name
        self.hostname = socket.gethostname()
        self.address = self.dask_worker.address
        self.scheduler_info = scheduler_info

        model_name = self.key.partition(':')[0]
        self.model_name = model_name[:-4] if model_name.endswith(
            '.zip') else model_name
        self.tf_option = json.loads(tf_option) if isinstance(
            tf_option, str) else tf_option
        self.tf_config = json.loads(tf_config) if isinstance(
            tf_config, str) else tf_config

        self.dask_cwd = os.path.abspath(os.getcwd())
        self.tf_model_pool_dir = os.path.abspath(DASK_MODEL_POOL_DIR)
        self.tf_data_pool_dir = os.path.abspath(DASK_DATA_POOL_DIR)
        self.tf_data_dir = os.path.join(self.tf_data_pool_dir, self.model_name)
        self.tf_config_dir = os.path.join(self.tf_data_dir, 'config')
        self.tf_save_dir = os.path.join(self.tf_data_dir, 'ckpt')
        self.tf_log_dir = os.path.join(self.tf_data_dir, 'log')

        # Start from an empty checkpoint dir and make sure every working dir
        # exists. Done in-process rather than through a shell command, since
        # the paths derive from the task key and `%r` is not shell quoting.
        import shutil
        shutil.rmtree(self.tf_save_dir, ignore_errors=True)
        for path in (self.tf_data_dir, self.tf_config_dir, self.tf_save_dir,
                     self.tf_log_dir):
            os.makedirs(path, exist_ok=True)
        os.chdir(self.tf_data_dir)

        # Save the current stdio so that `recovery` can restore it.
        self.sys_stdio = (sys.__stdin__, sys.__stdout__, sys.__stderr__,
                          sys.stdin, sys.stdout, sys.stderr)
        self.stdout = open(os.path.join(self.tf_log_dir, '%s.log' %
                                        self.key.partition(':')[-1]),
                           'a+',
                           encoding=sys.stdout.encoding)
        sys.__stdout__ = sys.__stderr__ = sys.stdout = sys.stderr = self.stdout
        self.stdin = sys.stdin
        self.sys_path, self.sys_argv = sys.path[:], sys.argv[:]

        logger.info(
            'Accepted Tensorflow Key:%s, Job:%s, Options:%s, Scheduler:%s',
            key, tf_config, tf_option, scheduler_info)
        self.devices = dict(tensorflow_devices())
        self.future_chunk_size = DASK_READ_CHUNK_SIZE
        self.args = args
        self.kwargs = kwargs

        # Own topic (task key) for replies; model topic for results.
        self.sub = Sub(self.key, worker=self.dask_worker)
        self.result = Pub(model_name, worker=self.dask_worker)
        self.exe = self.preflight()
        self.dask_worker.loop.add_callback(self.flight, self.exe)

    def device_info(self, xla=None, gpu=True):
        """Return the local devices of one kind, filtered by ``gpu_filter``.

        ``xla=None`` selects devices whose name contains 'GPU'.
        ``xla=True`` selects names containing 'XLA'. Any other value selects
        names not containing 'XLA'. ``gpu`` is passed as ``gpu_flag``.
        """
        if xla is None:
            return gpu_filter([
                v for (x, v) in self.devices.items()
                if v['name'].find('GPU') >= 0
            ],
                              gpu_flag=gpu)
        elif xla is True:
            return gpu_filter([
                v for (x, v) in self.devices.items()
                if v['name'].find('XLA') >= 0
            ],
                              gpu_flag=gpu)
        else:
            return gpu_filter([
                v
                for (x, v) in self.devices.items() if v['name'].find('XLA') < 0
            ],
                              gpu_flag=gpu)

    def tensorflow_env(self,
                       tf_option,
                       tf_config,
                       dask_context,
                       cuda_indexes=None):
        """Build the environment dict for the model subprocess.

        The entrypoint is ``<model_pool>/<model>/__main__.py`` if that file
        exists, else ``<model_pool>/<model>.zip``. Otherwise
        ``Exception(USAGE_INFO)`` is raised. A whitelist of variables is
        copied from the current environment and TF_*/PYTHON* settings are
        added. ``tf_option`` is accepted but not used here.
        """
        model_entrypoint = os.path.join(self.tf_model_pool_dir,
                                        self.model_name)
        zip_ep, pkg_ep = model_entrypoint + '.zip', os.path.join(
            model_entrypoint, '__main__.py')
        if os.path.exists(pkg_ep) and os.path.isfile(pkg_ep):
            model_entrypoint = pkg_ep
        elif os.path.exists(zip_ep) and os.path.isfile(zip_ep):
            model_entrypoint = zip_ep
        else:
            raise Exception(USAGE_INFO)

        env_dict = {}

        for key in ('LANG', 'PATH', 'CUDA_HOME', 'LD_LIBRARY_PATH', 'USER',
                    'HOME', 'HOSTNAME', 'SHELL', 'TERM', 'SHLVL', 'MAIL',
                    'SSH_CONNECTION', 'SSH_TTY', 'SSH_CLIENT'):
            val = os.getenv(key)
            if val is not None:
                env_dict[key] = val

        env_dict.update(
            XLA_FLAGS='--xla_hlo_profile',
            TF_DASK_PID=str(os.getpid()),
            # NOTE(review): 'RF_' looks like a typo for 'TF_'; kept because
            # the child process may read this exact name.
            RF_DASK_PGRP=str(os.getpgrp()),
            TF_XLA_FLAGS=("--tf_xla_cpu_global_jit " +
                          os.environ.get("TF_XLA_FLAGS", "")),
            TF_MODEL=self.model_name,
            TF_CONTEXT=json.dumps(dask_context),
            TF_CONFIG=json.dumps(tf_config),
            TF_MODEL_POOL_DIR=self.tf_model_pool_dir,
            TF_DATA_POOL_DIR=self.tf_data_pool_dir,
            TF_MODEL_ENTRYPOINT=model_entrypoint,
            TF_CONFIG_DIR=self.tf_config_dir,
            TF_DATA_DIR=self.tf_data_dir,
            TF_LOG_DIR=self.tf_log_dir,
            TF_SAVE_DIR=self.tf_save_dir,
            PYTHONPATH=':'.join(
                [self.tf_model_pool_dir, self.tf_data_dir, self.dask_cwd]),
            PYTHONHASHSEED=str(int(DASK_PYTHONHASHSEED)),
            PYTHONIOENCODING=sys.getdefaultencoding(),
            PYTHONUNBUFFERED='True',
        )

        if cuda_indexes:
            # GPU indexes were assigned explicitly; expose ONLY these to TF.
            env_dict['CUDA_VISIBLE_DEVICES'] = cuda_indexes

        return env_dict

    def log(self, msg, *args, flush=True):
        """Write ``msg`` (``%``-formatted with ``args`` if given) to the task log."""
        self.stdout.write((msg % args) if args else msg)
        if flush:
            self.stdout.flush()

    def run_model(self, stdin, stdout, *args, **kwargs):
        """Import the model entrypoint named by ``TF_MODEL_ENTRYPOINT`` in-process.

        Rebinds ``sys`` stdio to ``stdin``/``stdout``. For a ``.zip``
        entrypoint it imports ``__main__`` from the archive; for a ``.py``
        file it imports that module from its directory.

        Raises:
            ValueError: If the entrypoint is neither a ``.zip`` nor a ``.py``.
        """
        import sys
        sys.stdin = stdin
        self.stdout = sys.stdout = sys.stderr = sys.__stdout__ = sys.__stderr__ = stdout

        self.log('HERE IN ASYNC SUBPROCESS: %s' % os.getpid())

        model_entry = os.getenv('TF_MODEL_ENTRYPOINT')

        if model_entry.endswith('.zip'):
            model_root, modname = model_entry, '__main__'
        elif model_entry.endswith('.py'):
            model_root, modname = os.path.dirname(
                model_entry), os.path.basename(model_entry).rsplit('.', 1)[0]
        else:
            # Without this, `modname` would be unbound below.
            raise ValueError('Unsupported TF_MODEL_ENTRYPOINT: %r' %
                             (model_entry,))

        self.log('HERE IN ASYNC MODEL START, %s, %s' % (modname, model_root))
        sys.path.insert(0, model_root)
        __import__(modname)

    def preflight(self):
        """Build (but do not start) the launcher for the model subprocess.

        Pops ``'dask'`` from ``self.tf_config``, expected as
        ``'<node>:<port>:<cuda_indexes>:<dask_url>'``. It then resolves the
        session options: ``str``/``bytes`` are parsed as a ``ConfigProto``,
        other non-None values are used as-is, and ``None`` falls back to
        ``tensorflow_options()``. Returns a ``functools.partial`` of
        ``process.Subprocess`` bound to the command and environment.
        """
        # this NODE is selected for this task
        node_name, node_port, cuda_indexes, dask_url = self.tf_config.pop(
            'dask').split(':', 3)
        job_name, task_index = self.tf_config['task']['type'], self.tf_config[
            'task']['index']
        tensorflow_addr = self.tf_config['cluster'][job_name][task_index]

        using_xla_gpu_devices = self.device_info(xla=True, gpu=True)
        using_xla_gpu_device_names = sorted(
            [x['name'] for x in using_xla_gpu_devices])

        if isinstance(self.tf_option, (str, bytes)):
            import tensorflow as tf
            tf_option = tf.compat.v1.ConfigProto.FromString(self.tf_option)
        elif self.tf_option is not None:
            tf_option = self.tf_option
        else:
            tf_option = tensorflow_options()

        dask_context = {
            'model_task':
            '%s, %s' % (self.key, ','.join(using_xla_gpu_device_names)),
            'model_addr':
            tensorflow_addr,
            'worker_addr':
            self.address,
            'schduler_addr':
            self.scheduler_info,
            'workspace':
            DASK_WORKSPACE,
            'local_dir':
            self.dask_cwd,
            'pid':
            os.getpid(),
            'thread_id':
            self.thrid,
            'code':
            0,
        }

        env_dict = self.tensorflow_env(tf_option,
                                       self.tf_config,
                                       dask_context,
                                       cuda_indexes=cuda_indexes)
        cmd = [
            sys.executable, r'-u', env_dict['TF_MODEL_ENTRYPOINT'], self.key
        ]
        fmt = 'Model Start, key:%s,\n  cmd:%s\n  dask_context:%s\n  sys.path:%s\n  tf_option:%s\n  tf_config:%s\n\n'
        self.log(fmt % (self.key, cmd, dask_context, self.sys_path, tf_option,
                        self.tf_config))

        # Subprocess env values must be str; report offenders in the log.
        for k, v in env_dict.items():
            if not isinstance(k, str) or not isinstance(v, str):
                self.log('Error env k:%s, v:%s\n' % (k, v))

        exe_package = partial(process.Subprocess,
                              cmd,
                              executable=DASK_PYTHON_INTERPRETER,
                              cwd=env_dict['TF_DATA_DIR'],
                              env=env_dict,
                              preexec_fn=None,
                              stdin=self.stdin,
                              stdout=self.stdout,
                              stderr=self.stdout,
                              encoding=sys.getdefaultencoding(),
                              pass_fds=(self.stdin.fileno(),
                                        self.stdout.fileno()),
                              universal_newlines=False,
                              bufsize=0,
                              restore_signals=False,
                              start_new_session=False)

        return exe_package

    def flight(self, exe_package):
        """Start the model subprocess and arrange for :meth:`landed` on exit.

        ``__init__`` schedules this on the worker's event loop because
        ``SIGCHLD`` MUST be received in the main thread for the exit callback
        to fire.
        """
        loop = self.dask_worker.loop
        self.exe = exe_package()
        # `landed` is a coroutine function: calling it directly from the exit
        # callback would only create a coroutine object and never run its
        # body. Hand it to the event loop, which drives the coroutine.
        self.exe.set_exit_callback(
            lambda retval: loop.add_callback(self.landed, retval))
        msg = '\n ***** Tensorflow Task   Inited, key:%s, sub:%s, pid:%s ***** ' % (
            self.key, self.exe.pid, os.getpid())
        self.log(msg)

    async def landed(self, retval=0):
        """Report the subprocess exit status, then restore the saved state.

        Publishes ``{self.key: retval}`` on the model topic and waits for a
        message on this task's own topic (``timeout=10``). Both are logged
        before :meth:`recovery` is called.
        """
        self.log("worker pub msg: %s", {self.key: retval})
        self.result.put({self.key: retval})
        ident = await self.dask_worker.scheduler.identity()
        msg = await self.sub._get(timeout=10)
        self.log('Tensorflow Push Message Received, sub:%s, msg:%s, ident:%s' %
                 (self.key, msg, ident))
        msg = '\n ***** Tensorflow Task Finished, key:%s, ret:%s, tid:%s, pid:%s, ***** \n\n' % (
            self.key, retval, threading.get_ident(), os.getpid())
        self.log(msg)
        self.recovery()

    def recovery(self):
        """Undo the process-global changes made by ``__init__``.

        Restores the cwd, stdio, ``sys.path`` and ``sys.argv``. It closes the
        log file unless it is the original stdout, then drops references to
        the worker, Pub/Sub and subprocess. After the first call (or on an
        instance whose ``__init__`` never saved stdio) this is a no-op.
        """
        # `getattr` so that `__del__` on a partially constructed actor does
        # not raise AttributeError.
        if getattr(self, 'sys_stdio', None) is None:
            return

        self.log('State Recovery:%s', self.key)
        os.chdir(self.dask_cwd)
        sys.__stdin__, sys.__stdout__, sys.__stderr__, sys.stdin, sys.stdout, sys.stderr = self.sys_stdio
        sys.path, sys.argv = self.sys_path, self.sys_argv

        if self.stdin:
            if self.stdin != sys.__stdin__:
                self.stdin.close()
            else:
                self.stdin.flush()

        if self.stdout:
            if self.stdout != sys.__stdout__:
                self.stdout.close()
            else:
                self.stdout.flush()

        self.stdin = self.stdout = self.sys_stdio = self.sys_path = self.sys_argv = self.dask_worker = None
        del self.result
        del self.exe
        del self.sub
Exemplo n.º 13
0
def publish_event(dask_pub: distributed.Pub, event: BaseTaskEvent) -> None:
    """Serialize ``event`` via its ``json()`` method and publish it on ``dask_pub``."""
    payload = event.json()
    dask_pub.put(payload)
Exemplo n.º 14
0
    def run(self, n_run=None, verbose=True, n_update=None):
        """Advance the sampler by ``n_run`` iterations.

        While running, warnings are prefixed with ``self._prefix``. When
        ``self._dask_key`` is set, warnings and progress messages are
        published on that dask topic instead of being shown/printed.

        Args:
            n_run: Number of iterations. ``None`` runs the remaining
                ``n_iter - i_iter`` iterations of ``self._sample_trace``. A
                larger value extends ``self._sample_trace.n_iter``.
            verbose: Whether to report progress and a final summary.
            n_update: Report progress every ``n_update`` iterations. The
                default is ``n_run // 5``, clamped to at least 1. Invalid
                values fall back to the default with a ``RuntimeWarning``.

        Returns:
            ``self._sample_trace``.

        Raises:
            ValueError: If ``n_run`` is not a positive integer.
        """
        def as_positive_int(value):
            # int(value) if that is a positive integer, else None. Explicit
            # check instead of `assert`, which is stripped under `python -O`.
            try:
                value = int(value)
            except Exception:
                return None
            return value if value > 0 else None

        if self._dask_key is None:
            pub = None

            def sw(message, *args, **kwargs):
                warnings._showwarning_orig(self._prefix + str(message), *args,
                                           **kwargs)
        else:
            pub = Pub(self._dask_key)

            def sw(message, category, *args, **kwargs):
                pub.put([category, self._prefix + str(message)])

        def emit(tag, msg):
            # Publish `[tag, msg]` on the dask topic, or print locally while
            # holding the optional process lock (released even on error).
            if pub is not None:
                pub.put([tag, msg])
                return
            if self.has_lock:
                self.process_lock.acquire()
            try:
                print(msg)
            finally:
                if self.has_lock:
                    self.process_lock.release()

        try:
            warnings.showwarning = sw
            i_iter = self._sample_trace.i_iter
            n_iter = self._sample_trace.n_iter
            n_warmup = self._sample_trace.n_warmup
            if n_run is None:
                n_run = n_iter - i_iter
            else:
                n_run = as_positive_int(n_run)
                if n_run is None:
                    raise ValueError(self._prefix + 'invalid value for n_run.')
                if n_run > n_iter - i_iter:
                    self._sample_trace.n_iter = i_iter + n_run
                    n_iter = self._sample_trace.n_iter
            if verbose:
                # Clamp to 1: with n_run < 5, `n_run // 5` is 0 and the
                # `i % n_update` below would raise ZeroDivisionError.
                default_update = max(n_run // 5, 1)
                if n_update is None:
                    n_update = default_update
                else:
                    n_update = as_positive_int(n_update)
                    if n_update is None:
                        warnings.warn(
                            self._prefix + 'invalid value for n_update. Using '
                            'n_run//5 for now.', RuntimeWarning)
                        n_update = default_update
                t_s = time.time()
                t_i = time.time()
            for i in range(i_iter, i_iter + n_run):
                if verbose and i > i_iter and i % n_update == 0:
                    t_d = time.time() - t_i
                    t_i = time.time()
                    n_div = np.sum(
                        self._sample_trace.stats._diverging[-n_update:])
                    msg_0 = (self._prefix +
                             'sampling proceeding [ {} / {} ], '
                             'last {} samples used {:.2f} seconds'.format(
                                 i, n_iter, n_update, t_d))
                    if n_div / n_update > 0.05:
                        msg_1 = (', while divergence encountered in {} '
                                 'sample(s).'.format(n_div))
                    else:
                        msg_1 = '.'
                    msg_2 = ' (warmup)' if self.warmup else ''
                    emit('SamplingProceeding', msg_0 + msg_1 + msg_2)
                self.warmup = bool(i < n_warmup)
                self.astep()
            if verbose:
                t_f = time.time()
                msg = (self._prefix + 'sampling finished [ {} / {} ], '
                       'obtained {} samples in {:.2f} seconds.'.format(
                           n_iter, n_iter, n_run, t_f - t_s))
                emit('SamplingFinished', msg)
            return self._sample_trace
        finally:
            warnings.showwarning = warnings._showwarning_orig
Exemplo n.º 15
0
    def run(self,
            n_iter=3000,
            n_warmup=1000,
            verbose=True,
            n_update=None,
            return_copy=True):
        """Extend the trace by ``n_iter`` iterations and sample them.

        Args:
            n_iter: Additional iterations to append to ``self._trace``. Must
                be non-negative.
            n_warmup: Additional warmup iterations. Clamped to ``n_iter``, and
                forced to 0 when the trace shows warmup already completed.
            verbose: Whether to report progress and a final summary.
            n_update: Report progress every ``n_update`` iterations. The
                default is ``n_run // 5``, clamped to at least 1.
                Non-positive values fall back to the default with a
                ``RuntimeWarning``.
            return_copy: Return ``self.trace`` if true, else ``self._trace``.

        When ``self._dask_key`` is set, warnings, progress messages and (on
        any exception) ``['Error', self._chain_id]`` are published on that
        dask topic.

        Raises:
            ValueError: If ``n_iter`` is negative.
        """
        n_iter = int(n_iter)
        n_warmup = int(n_warmup)
        pub = None
        if self._dask_key is not None:
            pub = Pub(self._dask_key)

            def sw(message, category, *args, **kwargs):
                pub.put([category, self._prefix + str(message)])

            warnings.showwarning = sw

        def emit(tag, msg):
            # Print locally or publish `[tag, msg]` on the dask topic.
            if pub is None:
                print(msg)
            else:
                pub.put([tag, msg])

        try:
            if not n_iter >= 0:
                raise ValueError(self._prefix + 'n_iter cannot be negative.')
            if n_warmup > n_iter:
                warnings.warn(
                    self._prefix + 'n_warmup is larger than n_iter. Setting '
                    'n_warmup = n_iter for now.', RuntimeWarning)
                n_warmup = n_iter
            if self._trace.n_iter > self._trace.n_warmup and n_warmup > 0:
                warnings.warn(
                    self._prefix + 'self.trace indicates that warmup has '
                    'completed, so n_warmup will be set to 0.', RuntimeWarning)
                n_warmup = 0
            i_iter = self._trace.i_iter
            self._trace._n_iter += n_iter
            self._trace._n_warmup += n_warmup
            n_iter = self._trace._n_iter
            n_warmup = self._trace._n_warmup
            if verbose:
                n_run = n_iter - i_iter
                # Clamp to 1: with n_run < 5, `n_run // 5` is 0 and the
                # `i % n_update` below would raise ZeroDivisionError.
                default_update = max(n_run // 5, 1)
                if n_update is None:
                    n_update = default_update
                else:
                    n_update = int(n_update)
                    if n_update <= 0:
                        warnings.warn(
                            self._prefix + 'invalid n_update value. Using '
                            'n_run // 5 for now.', RuntimeWarning)
                        n_update = default_update
                t_s = time.time()
                t_i = time.time()
            for i in range(i_iter, n_iter):
                if verbose and i > i_iter and i % n_update == 0:
                    t_d = time.time() - t_i
                    t_i = time.time()
                    n_div = np.sum(
                        self._trace._stats._diverging[-n_update:])
                    msg_0 = (self._prefix +
                             'sampling proceeding [ {} / {} ], '
                             'last {} samples used {:.2f} seconds'.format(
                                 i, n_iter, n_update, t_d))
                    if n_div / n_update > 0.05:
                        msg_1 = (', while divergence encountered in {} '
                                 'sample(s).'.format(n_div))
                    else:
                        msg_1 = '.'
                    msg_2 = ' (warmup)' if self.warmup else ''
                    emit('SamplingProceeding', msg_0 + msg_1 + msg_2)
                self.warmup = bool(i < n_warmup)
                self.astep()
            if verbose:
                t_f = time.time()
                msg = (self._prefix + 'sampling finished [ {} / {} ], '
                       'obtained {} samples in {:.2f} seconds.'.format(
                           n_iter, n_iter, n_run, t_f - t_s))
                emit('SamplingFinished', msg)
            return self.trace if return_copy else self._trace
        except BaseException:
            # Tell the dask listener this chain failed, then propagate.
            if pub is not None:
                pub.put(['Error', self._chain_id])
            raise
        finally:
            warnings.showwarning = warnings._showwarning_orig
Exemplo n.º 16
0
 def f(x):
     """Send ``x`` to every subscriber of topic 'a' through a new ``Pub``."""
     publisher = Pub('a')
     return publisher.put(x) and None
Exemplo n.º 17
0
 def f(x):
     """Publish ``x`` on topic "a" via a freshly created ``Pub``."""
     Pub("a").put(x)
Exemplo n.º 18
0
def tensorflow_scheduler(global_future,
                         model_name,
                         client=None,
                         tf_option=None,
                         tf_port=None,
                         **tf_cluster_spec):
    """Launch a distributed TensorFlow job for ``model_name`` and wait for it.

    This is a generator-based coroutine (it ``yield``s client/scheduler
    futures), so it must be driven by a coroutine runner. Presumably that is
    tornado's ``gen.coroutine``, applied where this is used; TODO confirm.

    Args:
        global_future: Resolved with ``chief_actors`` (as returned through
            ``startup_actors``) once the chief task has reported back.
        model_name: Model name; also the pub/sub topic for task results.
        client: dask client used to reach the scheduler and workers.
        tf_option: Session options. Non-string truthy objects are serialized
            with ``SerializeToString()``; str/bytes/falsy pass through.
        tf_port: Not used by this function.
        **tf_cluster_spec: Forwarded to ``tensorflow_gen_config``.

    Raises:
        AssertionError: If no result message arrived for the chief task.
    """
    scheduler_info = yield client.scheduler.identity()
    cuda_free_map = yield client.run(cuda_free_indexes)
    tf_configs = tensorflow_gen_config(free_node_name_map=cuda_free_map,
                                       **tf_cluster_spec)

    logger.info('Model Schedule %s: \n  tf_configs:%s\n\n', model_name,
                tf_configs)

    if not isinstance(tf_option, (str, bytes)) and tf_option:
        tf_option = tf_option.SerializeToString()

    # One executor thread per TensorFlow task.
    client.loop.set_default_executor(
        ThreadPoolExecutor(max_workers=len(tf_configs)))

    chief_future = Future()
    client.loop.add_callback(startup_actors, scheduler_info, client,
                             model_name, tf_option, tf_configs, chief_future)
    chief_actors = yield chief_future

    sorted_task_keys = sorted(chief_actors.keys(), key=dask_sork_key)

    sub = Sub(model_name, client=client)
    # NOTE(review): `pubs` is not used below. It is kept in case constructing
    # the publishers has required side effects (presumably registration with
    # the scheduler); TODO confirm whether it can be dropped.
    pubs = {k: Pub(model_name, client=client) for k in sorted_task_keys}
    # Data flush sync between this client and the scheduler.
    scheduler_info = yield client.scheduler.identity()

    # Chief first: the lowest-sorted task key is treated as the chief.
    chief_key_actor = sorted_task_keys[0]

    # Collect result messages (presumably `{task_key: retval}` dicts
    # published by each actor) until all but one task have reported.
    msgs = {}
    while len(msgs) + 1 < len(chief_actors):
        msg = yield sub._get()
        logger.info('Sub Rcv %s:%s', type(msg), msg)
        msgs.update(msg)

    # Explicit raise instead of `assert`, which `python -O` strips.
    if chief_key_actor not in msgs:
        raise AssertionError(
            'Tensorflow Chief Task Required: %s' % chief_key_actor)
    # NOTE(review): `time.sleep` blocks the event loop inside a coroutine;
    # a non-blocking sleep is probably intended.
    time.sleep(1)
    yield model_cleanup(client, model_name)
    logger.info("Tensorflow Task clean, %s", chief_actors)
    global_future.set_result(chief_actors)