Example #1
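From the dask distributed test suite (older generator style): Sub.get should raise TimeoutError when no message arrives within the timeout, and should return promptly instead of hanging. The (c, s, a, b) signature suggests the test runs under a @gen_cluster(client=True) decorator that this excerpt omits.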
def test_timeouts(c, s, a, b):
    sub = Sub('a', client=c, worker=None)
    start = time()
    with pytest.raises(TimeoutError):
        yield sub.get(timeout=0.1)
    stop = time()
    assert stop - start < 1
Example #2
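Round-trips messages from worker-side Pub('a') calls to a client-side Sub, then polls the scheduler and worker pubsub extensions until every publisher and subscriber is cleaned up, including after the Sub itself is deleted.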
def test_client_worker(c, s, a, b):
    sub = Sub("a", client=c, worker=None)

    def f(x):
        pub = Pub("a")
        pub.put(x)

    futures = c.map(f, range(10))
    yield wait(futures)

    L = []
    for i in range(10):
        result = yield sub.get()
        L.append(result)

    assert set(L) == set(range(10))

    sps = s.extensions["pubsub"]
    aps = a.extensions["pubsub"]
    bps = b.extensions["pubsub"]

    start = time()
    while (sps.publishers["a"] or sps.subscribers["a"] or aps.publishers["a"]
           or bps.publishers["a"] or len(sps.client_subscribers["a"]) != 1):
        yield gen.sleep(0.01)
        assert time() < start + 3

    del sub

    start = time()
    while (sps.client_subscribers or any(aps.publish_to_scheduler.values())
           or any(bps.publish_to_scheduler.values())):
        yield gen.sleep(0.01)
        assert time() < start + 3
Example #3
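A task body that listens for cancellation: it subscribes to a cancel topic, parses each message as a TaskCancelEvent, and raises asyncio.CancelledError once an event matches the current task's key.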
    def fake_sidecar_fct(
        docker_auth: DockerBasicAuth,
        service_key: str,
        service_version: str,
        input_data: TaskInputData,
        output_data_keys: TaskOutputDataSchema,
        log_file_url: AnyUrl,
        command: List[str],
        expected_annotations: Dict[str, Any],
    ) -> TaskOutputData:
        sub = Sub(TaskCancelEvent.topic_name())
        # get the task data
        worker = get_worker()
        task = worker.tasks.get(worker.get_current_task())
        assert task is not None
        print(f"--> task {task=} started")
        assert task.annotations == expected_annotations
        # sleep a bit in case someone is aborting us
        print("--> waiting for task to be aborted...")
        for msg in sub:
            assert msg
            print(f"--> received cancellation msg: {msg=}")
            cancel_event = TaskCancelEvent.parse_raw(msg)  # type: ignore
            assert cancel_event
            if cancel_event.job_id == task.key:
                print("--> raising cancellation error now")
                raise asyncio.CancelledError("task cancelled")

        return TaskOutputData.parse_obj({"some_output_key": 123})
Example #4
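A smoke test: the topic name and the class name should both appear in the string representations of Pub and Sub.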
async def test_repr(c, s, a, b):
    pub = Pub("my-topic")
    sub = Sub("my-topic")
    assert "my-topic" in str(pub)
    assert "Pub" in str(pub)
    assert "my-topic" in str(sub)
    assert "Sub" in str(sub)
Example #5
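Client-only pub/sub with no worker involved: once the scheduler registers the client as the sole subscriber of topic 'a', a message published from the same process is received via sub.__anext__().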
def test_client(c, s):
    with pytest.raises(Exception):
        get_worker()
    sub = Sub("a")
    pub = Pub("a")

    sps = s.extensions["pubsub"]
    cps = c.extensions["pubsub"]

    start = time()
    while not set(sps.client_subscribers["a"]) == {c.id}:
        yield gen.sleep(0.01)
        assert time() < start + 3

    pub.put(123)

    result = yield sub.__anext__()
    assert result == 123
Example #6
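The newer async variant of the timeout test, showing that timeout also accepts a duration string such as '100ms' or a datetime.timedelta.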
async def test_timeouts(c, s, a, b):
    sub = Sub("a", client=c, worker=None)
    start = time()
    with pytest.raises(TimeoutError):
        await sub.get(timeout="100ms")
    stop = time()
    assert stop - start < 1
    with pytest.raises(TimeoutError):
        await sub.get(timeout=timedelta(milliseconds=10))
Example #7
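A synchronous poll for cancellation: sub.get is called with a short timeout and asyncio.TimeoutError is suppressed, so an empty queue simply means the task has not been aborted.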
def is_current_task_aborted(sub: distributed.Sub) -> bool:
    task: Optional[TaskState] = _get_current_task_state()
    logger.debug("found following TaskState: %s", task)
    if task is None:
        # the task was removed from the list of tasks this worker should work on, meaning it is aborted
        # NOTE: this does not work in distributed mode, hence we need to use Variables or PubSub
        return True

    with suppress(asyncio.TimeoutError):
        msg = sub.get(timeout="100ms")
        if msg:
            cancel_event = TaskCancelEvent.parse_raw(msg)  # type: ignore
            return bool(cancel_event.job_id == task.key)
    return False
Example #8
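A ping-pong benchmark helper: each side subscribes to one topic and publishes to the other, first spinning until pub.subscribers is non-empty so the opening message is not dropped.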
    def pingpong(a, b, start=False, n=1000, msg=1):
        sub = Sub(a)
        pub = Pub(b)

        while not pub.subscribers:
            sleep(0.01)

        if start:
            pub.put(msg)  # other sub may not have started yet

        for i in range(n):
            msg = next(sub)
            pub.put(msg)
            # if i % 100 == 0:
            #     print(a, b, i)
        return n
Example #9
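A scheduling coroutine for a TensorFlow-on-dask cluster: it starts the chief/ps actors, then drains startup messages from Sub(model_name, client=client) until the expected tasks have reported in.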
def tensorflow_scheduler(global_future,
                         model_name,
                         client=None,
                         tf_option=None,
                         tf_port=None,
                         **tf_cluster_spec):
    scheduler_info = yield client.scheduler.identity()
    cuda_free_map = yield client.run(cuda_free_indexes)
    tf_configs = tensorflow_gen_config(free_node_name_map=cuda_free_map,
                                       **tf_cluster_spec)

    logger.info('Model Schedule %s: \n  tf_configs:%s\n\n', model_name,
                tf_configs)

    tf_option = tf_option if isinstance(tf_option, (str, bytes)) else (
        tf_option.SerializeToString() if tf_option else tf_option)

    chief_configs, ps_configs, other_configs = [], [], []
    for tf_config in tf_configs:
        task_type = tf_config['task']['type']
        task_index = tf_config['task']['index']
        if task_type in ('chief', 'master'):
            chief_configs.append(tf_config)
        elif task_type in ('ps', ):
            ps_configs.append(tf_config)
        else:
            other_configs.append(tf_config)

    s_time = time.time()
    dt = datetime.datetime.now()

    chief_configs.sort(key=lambda cfg: cfg['task']['index'])
    ps_configs.sort(key=lambda cfg: cfg['task']['index'])
    other_configs.sort(key=lambda cfg: cfg['task']['index'])

    client.loop.set_default_executor(
        ThreadPoolExecutor(max_workers=len(tf_configs)))

    result_future = Future()
    result_future.tf_configs = tf_configs
    result_future.tf_option = tf_option
    result_future.cuda_map = cuda_free_map

    chief_future = Future()
    client.loop.add_callback(startup_actors, scheduler_info, client,
                             model_name, tf_option, tf_configs, chief_future)
    chief_actors = yield chief_future

    sorted_task_keys = sorted(chief_actors.keys(), key=dask_sork_key)

    sub = Sub(model_name, client=client)
    pubs = {k: Pub(model_name, client=client) for k in sorted_task_keys}
    # data flush sync between this client and the scheduler
    scheduler_info = yield client.scheduler.identity()

    def chief_finish(task_key, actor, fu):
        value = fu.result()
        logger.info('Tensorflow Finished[%s/%s], key:%s, val:%s',
                    len(chief_actors), len(tf_configs), task_key, value)
        chief_actors[task_key] = actor
        if len(chief_actors) == len(tf_configs):
            logger.info('Tensorflow Cluster All Finished: %s',
                        chief_actors.keys())

    # Chief First.
    msgs = {}
    chief_key_actor = sorted_task_keys[0]

    while (len(msgs) + 1) < len(chief_actors):
        msg = yield sub._get()
        logger.info('Sub Rcv %s:%s', type(msg), msg)
        msgs.update(msg)

    assert chief_key_actor in msgs, 'Tensorflow Chief Task Required: %s' % chief_key_actor
    time.sleep(1)
    future = yield model_cleanup(client, model_name)
    logger.info("Tensorflow Task clean, %s", chief_actors)
    global_future.set_result(chief_actors)
Example #10
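BayesFast's sample driver: NUTS workers are mapped across the cluster, and the client consumes 'SamplingProceeding', 'SamplingFinished' and 'Error' status messages from a Sub keyed on the client id until every chain finishes.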
def sample(density,
           client=None,
           n_chain=4,
           n_iter=None,
           n_warmup=None,
           trace=None,
           random_state=None,
           x_0=None,
           verbose=True,
           return_trace=False,
           density_options={},
           sampler_options={},
           sampler='NUTS'):
    # DEVELOPMENT NOTES
    # if use_surrogate is not specified in density_options
    # x_0 is interpreted as in original space and will be transformed
    # otherwise, x_0 is understood as in the specified space
    if not isinstance(density, (Density, DensityLite)):
        raise ValueError('density should be a Density or DensityLite.')
    if isinstance(trace, Trace):
        trace = [trace]
    if (hasattr(trace, '__iter__') and len(trace) > 0
            and all(isinstance(t, Trace) for t in trace)):
        n_chain = len(trace)
    else:
        try:
            n_chain = int(n_chain)
            assert n_chain > 0
        except Exception:
            raise ValueError('invalid value for n_chain.')
        trace = [None for i in range(n_chain)]

    if n_iter is None:
        n_iter = 3000 if (trace[0] is None) else 1000
    else:
        try:
            n_iter = int(n_iter)
            assert n_iter > 0
        except Exception:
            raise ValueError('invalid value for n_iter.')
    if n_warmup is None:
        n_warmup = 1000 if (trace[0] is None) else 0
    else:
        try:
            n_warmup = int(n_warmup)
            assert n_warmup > 0
        except Exception:
            raise ValueError('invalid value for n_warmup.')

    try:
        density_options = dict(density_options).copy()
    except Exception:
        raise ValueError('density_options should be a dict.')
    if 'use_surrogate' not in density_options:
        density_options['use_surrogate'] = True
    if 'original_space' not in density_options:
        density_options['original_space'] = False
        _transform_x = True
    else:
        _transform_x = False

    if trace[0] is None:
        if hasattr(random_state, '__iter__'):
            random_state = [bfrandom.check_state(rs) for rs in random_state]
            if len(random_state) < n_chain:
                raise ValueError('you did not give me enough random_state(s).')
        else:
            random_state = bfrandom.check_state(random_state)
            random_state = bfrandom.split_state(random_state, n_chain)
        if x_0 is None:
            dim = density.input_size
            x_0 = bfrandom.multivariate_normal(np.zeros(dim), np.eye(dim),
                                               n_chain)
        else:
            x_0 = np.atleast_2d(x_0)
            if x_0.shape[0] < n_chain:
                raise ValueError('you did not give me enough x_0(s).')
            x_0 = x_0[:n_chain, :]
            if _transform_x:
                x_0 = density.from_original(x_0)
    else:
        random_state = [None for i in range(n_chain)]
        x_0 = [None for i in range(n_chain)]

    try:
        client, _new_client = check_client(client)
        # dask_key = bfrandom.string()
        dask_key = 'BayesFast-' + client.id
        sub = Sub(dask_key)
        finished = 0

        if sampler == 'NUTS':

            def nuts_worker(i):
                with threadpool_limits(1):

                    def logp_and_grad(x):
                        return density.logp_and_grad(x, **density_options)

                    nuts = NUTS(logp_and_grad=logp_and_grad,
                                trace=trace[i],
                                dask_key=dask_key,
                                chain_id=i,
                                random_state=random_state[i],
                                x_0=x_0[i],
                                **sampler_options)
                    t = nuts.run(n_iter, n_warmup, verbose)
                return t  # if return_trace else t.get()

            foo = client.map(nuts_worker, range(n_chain))
            for msg in sub:
                if not hasattr(msg, '__iter__'):
                    warnings.warn('unexpected message: {}.'.format(msg),
                                  RuntimeWarning)
                elif msg[0] == 'Error':
                    break
                elif isclass(msg[0]) and issubclass(msg[0], Warning):
                    warnings.warn(msg[1], msg[0])
                elif msg[0] == 'SamplingProceeding':
                    print(msg[1])
                elif msg[0] == 'SamplingFinished':
                    print(msg[1])
                    finished += 1
                else:
                    warnings.warn('unexpected message: {}.'.format(msg),
                                  RuntimeWarning)
                if finished == n_chain:
                    break
            tt = client.gather(foo)
            if _transform_x:
                xx = np.array([density.to_original(t.get()) for t in tt])
            else:
                xx = np.array([t.get() for t in tt])
            return (xx, tt) if return_trace else xx

        elif sampler == 'HMC':
            raise NotImplementedError

        elif sampler == 'EnsembleSampler':
            raise NotImplementedError

        else:
            raise ValueError(
                'Sorry I do not know how to do {}.'.format(sampler))
    finally:
        if _new_client:
            client.cluster.close()
            client.close()
Example #11
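Treating a Sub as a plain iterator: take a fixed number of messages with toolz.take.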
def f(_):
    sub = Sub("a")
    return list(toolz.take(5, sub))
Example #12
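A later revision of the BayesFast driver supporting several parallel backends; only the dask backend creates a Sub(dask_key), and workers report failures by publishing ['Error', i] through a matching Pub inside _sampler_worker.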
def sample(density,
           sample_trace=None,
           sampler='NUTS',
           n_run=None,
           parallel_backend=None,
           verbose=True):
    if not isinstance(density, (Density, DensityLite)):
        raise ValueError('density should be a Density or DensityLite.')

    if isinstance(sample_trace, NTrace):
        sampler = 'NUTS'
    elif isinstance(sample_trace, HTrace):
        sampler = 'HMC'
    elif isinstance(sample_trace, TNTrace):
        sampler = 'TNUTS'
    elif isinstance(sample_trace, THTrace):
        sampler = 'THMC'
    elif isinstance(sample_trace, ETrace):
        raise NotImplementedError
    elif sample_trace is None or isinstance(sample_trace, dict):
        sample_trace = {} if (sample_trace is None) else sample_trace
        if sampler == 'NUTS':
            sample_trace = NTrace(**sample_trace)
        elif sampler == 'HMC':
            sample_trace = HTrace(**sample_trace)
        elif sampler == 'TNUTS':
            sample_trace = TNTrace(**sample_trace)
        elif sampler == 'THMC':
            sample_trace = THTrace(**sample_trace)
        elif sampler == 'Ensemble':
            raise NotImplementedError
        else:
            raise ValueError('unexpected value for sampler.')
    elif isinstance(sample_trace, TraceTuple):
        sampler = sample_trace.sampler
        if sampler in ('NUTS', 'HMC', 'TNUTS', 'THMC'):
            pass
        elif sampler == 'Ensemble':
            raise NotImplementedError
        else:
            raise ValueError('unexpected value for sample_trace.sampler.')
    else:
        raise ValueError('unexpected value for sample_trace.')

    if isinstance(sample_trace, SampleTrace):
        if sample_trace.random_generator is None:
            sample_trace.random_generator = get_generator()
            get_generator().normal()
        if sample_trace.x_0 is None:
            dim = density.input_size
            if dim is None:
                raise RuntimeError('Neither SampleTrace.x_0 nor Density'
                                   '/DensityLite.input_size is defined.')
            sample_trace._x_0 = multivariate_normal(np.zeros(dim), np.eye(dim),
                                                    sample_trace.n_chain)
            sample_trace._x_0_transformed = True
        elif not sample_trace.x_0_transformed:
            sample_trace._x_0 = density.from_original(sample_trace._x_0)
            sample_trace._x_0_transformed = True

    if parallel_backend is None:
        parallel_backend = get_backend()
    else:
        parallel_backend = ParallelBackend(parallel_backend)

    if parallel_backend.kind == 'multiprocess':
        use_dask = False
        dask_key = None
        process_lock = Manager().Lock()
    elif parallel_backend.kind == 'ray':
        use_dask = False
        dask_key = None
        process_lock = None
    elif parallel_backend.kind == 'dask':
        if not HAS_DASK:
            raise RuntimeError('you want me to use dask but have not '
                               'installed it.')
        use_dask = True
        dask_key = 'BayesFast-' + parallel_backend.backend.id
        process_lock = None
        sub = Sub(dask_key)
        finished = 0
    elif parallel_backend.kind == 'sharedmem':
        use_dask = False
        dask_key = None
        process_lock = None
    elif parallel_backend.kind == 'loky':
        use_dask = False
        dask_key = None
        process_lock = None
    # elif parallel_backend.kind == 'serial':
    #     use_dask = False
    #     dask_key = None
    #     process_lock = None
    else:
        raise RuntimeError('unexpected value for parallel_backend.kind.')

    def nested_helper(sample_trace, i):
        """Without this, there will be an UnboundLocalError."""
        if isinstance(sample_trace, SampleTrace):
            sample_trace._init_chain(i)
        elif isinstance(sample_trace, TraceTuple):
            sample_trace = sample_trace.sample_traces[i]
        else:
            raise RuntimeError('unexpected type for sample_trace.')
        return sample_trace

    def _sampler_worker(i, sampler_class):
        try:
            with threadpool_limits(1):
                _sample_trace = nested_helper(sample_trace, i)

                def logp_and_grad(x):
                    return density.logp_and_grad(x, original_space=False)

                _sampler = sampler_class(logp_and_grad=logp_and_grad,
                                         sample_trace=_sample_trace,
                                         dask_key=dask_key,
                                         process_lock=process_lock)
                t = _sampler.run(n_run, verbose)
                t._samples_original = density.to_original(t.samples)
                t._logp_original = density.to_original_density(
                    t.logp, x_trans=t.samples)
            return t
        except Exception:
            if use_dask:
                pub = Pub(dask_key)
                pub.put(['Error', i])
            raise

    with parallel_backend:
        if sampler in ('NUTS', 'HMC', 'TNUTS', 'THMC'):
            if use_dask:
                foo = parallel_backend.map_async(_sampler_worker,
                                                 range(sample_trace.n_chain),
                                                 [eval(sampler)] *
                                                 sample_trace.n_chain)
                for msg in sub:
                    if not hasattr(msg, '__iter__'):
                        warnings.warn('unexpected message: {}.'.format(msg),
                                      RuntimeWarning)
                    elif msg[0] == 'Error':
                        break
                    elif isclass(msg[0]) and issubclass(msg[0], Warning):
                        warnings.warn(msg[1], msg[0])
                    elif msg[0] == 'SamplingProceeding':
                        print(msg[1])
                    elif msg[0] == 'SamplingFinished':
                        print(msg[1])
                        finished += 1
                    else:
                        warnings.warn('unexpected message: {}.'.format(msg),
                                      RuntimeWarning)
                    if finished == sample_trace.n_chain:
                        break
                tt = parallel_backend.gather(foo)
            else:
                tt = parallel_backend.map(_sampler_worker,
                                          range(sample_trace.n_chain),
                                          [eval(sampler)] *
                                          sample_trace.n_chain)
            return TraceTuple(tt)

        elif sampler == 'Ensemble':
            raise NotImplementedError

        else:
            raise RuntimeError('unexpected value for sampler.')
Example #13
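A long-running dask actor that pairs a per-task Sub(self.key, worker=...) with a result Pub(model_name, ...): it launches a TensorFlow subprocess, publishes the exit status when the process lands, then waits on the Sub for an acknowledgement before restoring the worker's stdio and working directory.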
class TFActor(object):
    def __del__(self):
        self.recovery()

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.key)

    def __init__(self,
                 key,
                 *args,
                 tf_config=None,
                 tf_option=None,
                 scheduler_info=None,
                 **kwargs):
        # here we made this thread an OWNER of this task by secede from it's ThreadPoolExecutor.
        # NOTE: `thread_pool_secede` ONLY works in NON-coroutine actor_exectutor, ref:worker.actor_execute()
        self.dask_worker = get_worker()
        self.thrid = threading.get_ident()

        thread_pool_secede(adjust=True)
        self.dask_worker.loop.add_callback(self.dask_worker.transition,
                                           thread_state.key, "long-running")

        self.key = key
        self.name = self.dask_worker.name
        self.hostname = socket.gethostname()
        self.address = self.dask_worker.address
        self.scheduler_info = scheduler_info

        model_name = self.key.partition(':')[0]
        self.model_name = model_name[:-4] if model_name.endswith(
            '.zip') else model_name
        self.tf_option = json.loads(tf_option) if isinstance(
            tf_option, str) else tf_option
        self.tf_config = json.loads(tf_config) if isinstance(
            tf_config, str) else tf_config

        self.dask_cwd = os.path.abspath(os.getcwd())
        self.tf_model_pool_dir = os.path.abspath(DASK_MODEL_POOL_DIR)
        self.tf_data_pool_dir = os.path.abspath(DASK_DATA_POOL_DIR)
        self.tf_data_dir = os.path.join(self.tf_data_pool_dir, self.model_name)
        self.tf_config_dir = os.path.join(self.tf_data_dir, 'config')
        self.tf_save_dir = os.path.join(self.tf_data_dir, 'ckpt')
        self.tf_log_dir = os.path.join(self.tf_data_dir, 'log')
        os.system('mkdir -p %r; rm -rf %r; mkdir -p %r %r %r %r' %
                  (self.tf_save_dir, self.tf_save_dir, self.tf_data_dir,
                   self.tf_config_dir, self.tf_save_dir, self.tf_log_dir))
        os.chdir(self.tf_data_dir)

        self.sys_stdio = (sys.__stdin__, sys.__stdout__, sys.__stderr__,
                          sys.stdin, sys.stdout, sys.stderr)
        self.stdout = open(os.path.join(self.tf_log_dir, '%s.log' %
                                        self.key.partition(':')[-1]),
                           'a+',
                           encoding=sys.stdout.encoding)
        sys.__stdout__ = sys.__stderr__ = sys.stdout = sys.stderr = self.stdout
        self.stdin = sys.stdin
        self.sys_path, self.sys_argv = sys.path[:], sys.argv[:]

        logger.info(
            'Accepted Tensorflow Key:%s, Job:%s, Options:%s, Scheduler:%s',
            key, tf_config, tf_option, scheduler_info)
        self.devices = dict(tensorflow_devices())
        self.future_chunk_size = DASK_READ_CHUNK_SIZE
        self.args = args
        self.kwargs = kwargs

        self.sub = Sub(self.key, worker=self.dask_worker)
        self.result = Pub(model_name, worker=self.dask_worker)
        self.exe = self.preflight()
        self.dask_worker.loop.add_callback(self.flight, self.exe)

    def device_info(self, xla=None, gpu=True):
        if xla is None:
            return gpu_filter([
                v for (x, v) in self.devices.items()
                if v['name'].find('GPU') >= 0
            ],
                              gpu_flag=gpu)
        elif xla is True:
            return gpu_filter([
                v for (x, v) in self.devices.items()
                if v['name'].find('XLA') >= 0
            ],
                              gpu_flag=gpu)
        else:
            return gpu_filter([
                v
                for (x, v) in self.devices.items() if v['name'].find('XLA') < 0
            ],
                              gpu_flag=gpu)

    def tensorflow_env(self,
                       tf_option,
                       tf_config,
                       dask_context,
                       cuda_indexes=None):
        model_entrypoint = os.path.join(self.tf_model_pool_dir,
                                        self.model_name)
        zip_ep, pkg_ep = model_entrypoint + '.zip', os.path.join(
            model_entrypoint, '__main__.py')
        if os.path.exists(pkg_ep) and os.path.isfile(pkg_ep):
            model_entrypoint = pkg_ep
        elif os.path.exists(zip_ep) and os.path.isfile(zip_ep):
            model_entrypoint = zip_ep
        else:
            raise Exception(USAGE_INFO)

        env_dict = {}

        for key in ('LANG', 'PATH', 'CUDA_HOME', 'LD_LIBRARY_PATH', 'USER',
                    'HOME', 'HOSTNAME', 'SHELL', 'TERM', 'SHLVL', 'MAIL',
                    'SSH_CONNECTION', 'SSH_TTY', 'SSH_CLIENT'):
            val = os.getenv(key)
            if val is not None:
                env_dict[key] = val

        env_dict.update(
            XLA_FLAGS='--xla_hlo_profile',
            TF_DASK_PID=str(os.getpid()),
            RF_DASK_PGRP=str(os.getpgrp()),
            TF_XLA_FLAGS=("--tf_xla_cpu_global_jit " +
                          os.environ.get("TF_XLA_FLAGS", "")),
            TF_MODEL=self.model_name,
            TF_CONTEXT=json.dumps(dask_context),
            TF_CONFIG=json.dumps(tf_config),
            TF_MODEL_POOL_DIR=self.tf_model_pool_dir,
            TF_DATA_POOL_DIR=self.tf_data_pool_dir,
            TF_MODEL_ENTRYPOINT=model_entrypoint,
            TF_CONFIG_DIR=self.tf_config_dir,
            TF_DATA_DIR=self.tf_data_dir,
            TF_LOG_DIR=self.tf_log_dir,
            TF_SAVE_DIR=self.tf_save_dir,
            PYTHONPATH=':'.join(
                [self.tf_model_pool_dir, self.tf_data_dir, self.dask_cwd]),
            PYTHONHASHSEED=str(int(DASK_PYTHONHASHSEED)),
            PYTHONIOENCODING=sys.getdefaultencoding(),
            PYTHONUNBUFFERED='True',
        )

        if cuda_indexes:  # we explicitly assign GPU indexes to use; let tensorflow aware of ONLY these indexes
            env_dict['CUDA_VISIBLE_DEVICES'] = cuda_indexes

        return env_dict

    def log(self, msg, *args, flush=True):
        self.stdout.write((msg % args) if args else msg)
        if flush:
            self.stdout.flush()

    def run_model(self, stdin, stdout, *args, **kwargs):
        import sys
        sys.stdin = stdin
        self.stdout = sys.stdout = sys.stderr = sys.__stdout__ = sys.__stderr__ = stdout

        self.log('HERE IN ASYNC SUBPROCESS: %s' % os.getpid())

        model_name = os.getenv('TF_MODEL')
        model_entry = os.getenv('TF_MODEL_ENTRYPOINT')

        if model_entry.endswith('.zip'):
            model_root, modname = model_entry, '__main__'
        elif model_entry.endswith('.py'):
            model_root, modname = os.path.dirname(
                model_entry), os.path.basename(model_entry).rsplit('.', 1)[0]

        self.log('HERE IN ASYNC MODEL START, %s, %s' % (modname, model_root))
        sys.path.insert(0, model_root)
        __import__(modname)

    def preflight(self):
        # this NODE is selected for this task
        node_name, node_port, cuda_indexes, dask_url = self.tf_config.pop(
            'dask').split(':', 3)
        job_name, task_index = self.tf_config['task']['type'], self.tf_config[
            'task']['index']
        tensorflow_addr = self.tf_config['cluster'][job_name][task_index]

        using_xla_gpu_devices = self.device_info(xla=True, gpu=True)
        using_xla_gpu_device_names = sorted(
            [x['name'] for x in using_xla_gpu_devices])

        if isinstance(self.tf_option, (str, bytes)):
            import tensorflow as tf
            tf_option = tf.compat.v1.ConfigProto.FromString(self.tf_option)
        elif self.tf_option is not None:
            tf_option = self.tf_option
        else:
            tf_option = tensorflow_options()

        dask_context = {
            'model_task':
            '%s, %s' % (self.key, ','.join(using_xla_gpu_device_names)),
            'model_addr':
            tensorflow_addr,
            'worker_addr':
            self.address,
            'schduler_addr':
            self.scheduler_info,
            'workspace':
            DASK_WORKSPACE,
            'local_dir':
            self.dask_cwd,
            'pid':
            os.getpid(),
            'thread_id':
            self.thrid,
            'code':
            0,
        }

        env_dict = self.tensorflow_env(tf_option,
                                       self.tf_config,
                                       dask_context,
                                       cuda_indexes=cuda_indexes)
        cmd = [
            sys.executable, r'-u', env_dict['TF_MODEL_ENTRYPOINT'], self.key
        ]
        fmt = 'Model Start, key:%s,\n  cmd:%s\n  dask_context:%s\n  sys.path:%s\n  tf_option:%s\n  tf_config:%s\n\n'
        self.log(fmt % (self.key, cmd, dask_context, self.sys_path, tf_option,
                        self.tf_config))

        for k, v in env_dict.items():
            if not isinstance(k, str) or not isinstance(v, str):
                self.log('Error env k:%s, v:%s\n' % (k, v))

        exe_package = partial(process.Subprocess,
                              cmd,
                              executable=DASK_PYTHON_INTERPRETER,
                              cwd=env_dict['TF_DATA_DIR'],
                              env=env_dict,
                              preexec_fn=None,
                              stdin=self.stdin,
                              stdout=self.stdout,
                              stderr=self.stdout,
                              encoding=sys.getdefaultencoding(),
                              pass_fds=(self.stdin.fileno(),
                                        self.stdout.fileno()),
                              universal_newlines=False,
                              bufsize=0,
                              restore_signals=False,
                              start_new_session=False)

        return exe_package

    def flight(self, exe_package):
        # flighting in main thread, since `SIGCHLD` MUST received in it; and then correctly call exit callback.
        self.exe = exe_package()
        self.exe.set_exit_callback(self.landed)
        msg = '\n ***** Tensorflow Task   Inited, key:%s, sub:%s, pid:%s ***** ' % (
            self.key, self.exe.pid, os.getpid())
        self.log(msg)

    def landed(self, retval=0):
        # NOTE: this method uses `yield`, so calling it returns a generator;
        # as a plain Subprocess exit callback its body would never run unless
        # it is driven as a coroutine (e.g. wrapped with gen.coroutine).
        self.log("worker pub msg: %s", {self.key: retval})
        self.result.put({self.key: retval})
        ident = yield self.dask_worker.scheduler.identity()
        msg = yield self.sub._get(timeout=10)
        self.log('Tensorflow Push Message Received, sub:%s, msg:%s, ident:%s' %
                 (self.key, msg, ident))
        msg = '\n ***** Tensorflow Task Finished, key:%s, ret:%s, tid:%s, pid:%s, ***** \n\n' % (
            self.key, retval, threading.get_ident(), os.getpid())
        self.log(msg)
        self.recovery()

    def recovery(self):
        if self.sys_stdio is None:
            return

        self.log('State Recovery:%s', self.key)
        os.chdir(self.dask_cwd)
        sys.__stdin__, sys.__stdout__, sys.__stderr__, sys.stdin, sys.stdout, sys.stderr = self.sys_stdio
        sys.path, sys.argv = self.sys_path, self.sys_argv

        if self.stdin:
            if self.stdin != sys.__stdin__:
                self.stdin.close()
            else:
                self.stdin.flush()

        if self.stdout:
            if self.stdout != sys.__stdout__:
                self.stdout.close()
            else:
                self.stdout.flush()

        self.stdin = self.stdout = self.sys_stdio = self.sys_path = self.sys_argv = self.dask_worker = None
        del self.result
        del self.exe
        del self.sub
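Putting it together: a minimal end-to-end sketch distilled from the patterns above. It is not taken from any of the listed projects; the topic name, message count, and local Client() are arbitrary choices.

from time import sleep

from distributed import Client, Pub, Sub


def producer(n):
    pub = Pub("my-topic")
    while not pub.subscribers:  # same guard as the ping-pong example above
        sleep(0.01)
    for i in range(n):
        pub.put(i)
    return n


if __name__ == "__main__":
    client = Client()           # local in-process cluster
    sub = Sub("my-topic")       # subscribe before publishing starts
    future = client.submit(producer, 5)
    received = [next(sub) for _ in range(5)]
    assert sorted(received) == list(range(5))
    assert future.result() == 5
    client.close()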