def test_timeouts(c, s, a, b):
    """Check that ``Sub.get`` gives up quickly when a timeout is passed.

    No message is published in this test, so the 0.1 second timeout must
    surface as ``TimeoutError`` in under one second of wall-clock time.
    """
    subscriber = Sub('a', client=c, worker=None)

    began = time()
    with pytest.raises(TimeoutError):
        yield subscriber.get(timeout=0.1)
    elapsed = time() - began

    assert elapsed < 1
def test_timeouts(c, s, a, b):
    """A ``Sub.get`` call with ``timeout=0.1`` must raise ``TimeoutError``.

    The whole wait is also required to finish in less than one second.
    """
    topic_sub = Sub("a", client=c, worker=None)

    t0 = time()
    with pytest.raises(TimeoutError):
        yield topic_sub.get(timeout=0.1)
    t1 = time()

    waited = t1 - t0
    assert waited < 1
def test_client_worker(c, s, a, b):
    """Publish from worker tasks, receive on a client-side ``Sub``.

    Ten tasks each publish their argument on topic "a"; the subscriber must
    receive exactly the values 0..9.  Afterwards the pubsub extensions on the
    scheduler and both workers are polled (for at most 3 seconds each time)
    until their bookkeeping for the topic has drained.
    """
    subscriber = Sub("a", client=c, worker=None)

    def f(x):
        # Runs inside a task: publish the task argument on the topic.
        publisher = Pub("a")
        publisher.put(x)

    futures = c.map(f, range(10))
    yield wait(futures)

    received = []
    for _ in range(10):
        item = yield subscriber.get()
        received.append(item)

    assert set(received) == set(range(10))

    scheduler_ext = s.extensions["pubsub"]
    worker_a_ext = a.extensions["pubsub"]
    worker_b_ext = b.extensions["pubsub"]

    # Wait until only the single client subscriber is still registered.
    deadline = time() + 3
    while (
        scheduler_ext.publishers["a"]
        or scheduler_ext.subscribers["a"]
        or worker_a_ext.publishers["a"]
        or worker_b_ext.publishers["a"]
        or len(scheduler_ext.client_subscribers["a"]) != 1
    ):
        yield gen.sleep(0.01)
        assert time() < deadline

    # Dropping the last reference to the subscriber should unregister it.
    del subscriber

    deadline = time() + 3
    while (
        scheduler_ext.client_subscribers
        or any(worker_a_ext.publish_to_scheduler.values())
        or any(worker_b_ext.publish_to_scheduler.values())
    ):
        yield gen.sleep(0.01)
        assert time() < deadline
def fake_sidecar_fct(
    docker_auth: DockerBasicAuth,
    service_key: str,
    service_version: str,
    input_data: TaskInputData,
    output_data_keys: TaskOutputDataSchema,
    log_file_url: AnyUrl,
    command: List[str],
    expected_annotations: Dict[str, Any],
) -> TaskOutputData:
    """Stand-in sidecar task that waits to be cancelled through pub/sub.

    The function subscribes to the ``TaskCancelEvent`` topic, checks that the
    currently running task carries ``expected_annotations`` and then consumes
    messages from the subscription.  When a cancel event whose ``job_id``
    equals the key of the current task arrives, ``asyncio.CancelledError`` is
    raised.  The final ``return`` is only reached if iterating the
    subscription ends.
    """
    sub = Sub(TaskCancelEvent.topic_name())

    # Look up the state of the task this function is currently running as.
    # NOTE: the local names ``task`` and ``msg`` are printed by the ``{x=}``
    # f-strings below, so they are part of the emitted text.
    this_worker = get_worker()
    task = this_worker.tasks.get(this_worker.get_current_task())
    assert task is not None
    print(f"--> task {task=} started")
    assert task.annotations == expected_annotations

    # Block on the subscription in case someone is aborting us.
    print("--> waiting for task to be aborted...")
    for msg in sub:
        assert msg
        print(f"--> received cancellation msg: {msg=}")
        cancel_event = TaskCancelEvent.parse_raw(msg)  # type: ignore
        assert cancel_event
        if cancel_event.job_id != task.key:
            # Cancellation was addressed to another task; keep waiting.
            continue
        print("--> raising cancellation error now")
        raise asyncio.CancelledError("task cancelled")

    return TaskOutputData.parse_obj({"some_output_key": 123})
async def test_repr(c, s, a, b):
    """The string form of ``Pub``/``Sub`` names both the topic and the class."""
    topic = "my-topic"
    publisher = Pub("my-topic")
    subscriber = Sub("my-topic")

    pub_text = str(publisher)
    assert topic in pub_text
    assert "Pub" in pub_text

    sub_text = str(subscriber)
    assert topic in sub_text
    assert "Sub" in sub_text
def test_client(c, s):
    """Pub/Sub used directly from a client, with no worker in scope.

    ``get_worker()`` is expected to fail here.  Once the scheduler's pubsub
    extension lists this client as the only client subscriber of topic "a"
    (polled for at most 3 seconds), a published value must come back through
    ``Sub.__anext__``.
    """
    with pytest.raises(Exception):
        get_worker()

    subscriber = Sub("a")
    publisher = Pub("a")

    scheduler_ext = s.extensions["pubsub"]
    client_ext = c.extensions["pubsub"]  # looked up only; not used further

    deadline = time() + 3
    while set(scheduler_ext.client_subscribers["a"]) != {c.id}:
        yield gen.sleep(0.01)
        assert time() < deadline

    publisher.put(123)

    received = yield subscriber.__anext__()
    assert received == 123
def test_client(c, s):
    """Round-trip one message between a client-side ``Pub`` and ``Sub``.

    The test first asserts that no worker is available via ``get_worker()``,
    then waits (at most 3 seconds) for the scheduler to register this client
    as the sole client subscriber of topic 'a' before publishing ``123`` and
    reading it back.
    """
    with pytest.raises(Exception):
        get_worker()

    sub_a = Sub('a')
    pub_a = Pub('a')

    sched_pubsub = s.extensions['pubsub']
    client_pubsub = c.extensions['pubsub']  # looked up only; not used further

    give_up_at = time() + 3
    while set(sched_pubsub.client_subscribers['a']) != {c.id}:
        yield gen.sleep(0.01)
        assert time() < give_up_at

    pub_a.put(123)

    value = yield sub_a.__anext__()
    assert value == 123
async def test_timeouts(c, s, a, b):
    """``Sub.get`` accepts string and ``timedelta`` timeouts and honours them.

    Nothing is published in this test, so both calls must raise
    ``TimeoutError``; the first one is additionally required to return in
    under one second.
    """
    subscriber = Sub("a", client=c, worker=None)

    began = time()
    with pytest.raises(TimeoutError):
        await subscriber.get(timeout="100ms")
    elapsed = time() - began
    assert elapsed < 1

    short_timeout = timedelta(milliseconds=10)
    with pytest.raises(TimeoutError):
        await subscriber.get(timeout=short_timeout)
def test_client_worker(c, s, a, b):
    """Worker tasks publish on topic 'a'; a client-side ``Sub`` collects them.

    After the ten values 0..9 have been received, the pubsub state on the
    scheduler and on both workers is polled until it has drained, first with
    the subscriber still alive and then after it has been deleted.  Each
    polling phase may take at most 3 seconds.
    """
    client_sub = Sub('a', client=c, worker=None)

    def f(x):
        # Executed as a task: publish the argument on the topic.
        task_pub = Pub('a')
        task_pub.put(x)

    futures = c.map(f, range(10))
    yield wait(futures)

    values = []
    for _ in range(10):
        value = yield client_sub.get()
        values.append(value)

    assert set(values) == set(range(10))

    sched_ps = s.extensions['pubsub']
    a_ps = a.extensions['pubsub']
    b_ps = b.extensions['pubsub']

    # Phase 1: only the one client subscriber may remain registered.
    give_up_at = time() + 3
    while (
        sched_ps.publishers['a']
        or sched_ps.subscribers['a']
        or a_ps.publishers['a']
        or b_ps.publishers['a']
        or len(sched_ps.client_subscribers['a']) != 1
    ):
        yield gen.sleep(0.01)
        assert time() < give_up_at

    # Phase 2: releasing the subscriber must clear the remaining state.
    del client_sub

    give_up_at = time() + 3
    while (
        sched_ps.client_subscribers
        or any(a_ps.publish_to_scheduler.values())
        or any(b_ps.publish_to_scheduler.values())
    ):
        yield gen.sleep(0.01)
        assert time() < give_up_at
def is_current_task_aborted(sub: distributed.Sub) -> bool:
    """Tell whether the task currently being executed has been aborted.

    Args:
        sub: subscription on which cancel events are received.

    Returns:
        ``True`` when no task state is found for the current task, or when a
        cancel event whose ``job_id`` equals the current task key is received
        within 100ms; ``False`` otherwise (including on timeout).
    """
    task: Optional[TaskState] = _get_current_task_state()
    logger.debug("found following TaskState: %s", task)

    if task is None:
        # the task was removed from the list of tasks this worker should work
        # on, meaning it is aborted
        # NOTE: this does not work in distributed mode, hence we need to use
        # Variables, or PubSub
        return True

    try:
        msg = sub.get(timeout="100ms")
        if not msg:
            return False
        cancel_event = TaskCancelEvent.parse_raw(msg)  # type: ignore
        return bool(cancel_event.job_id == task.key)
    except asyncio.TimeoutError:
        # Nothing arrived in time: treat the task as still running.
        return False
def pingpong(a, b, start=False, n=1000, msg=1):
    """Relay ``n`` messages from topic ``a`` to topic ``b``.

    Args:
        a: topic to subscribe to.
        b: topic to publish on.
        start: when true, publish ``msg`` once before relaying so that the
            exchange gets going.
        n: number of messages to receive and forward.
        msg: initial message sent when ``start`` is true.

    Returns:
        ``n``, once all messages have been forwarded.
    """
    inbox = Sub(a)
    outbox = Pub(b)

    # Busy-wait until the publisher knows about at least one subscriber.
    while not outbox.subscribers:
        sleep(0.01)

    if start:
        # other sub may not have started yet
        outbox.put(msg)

    for _ in range(n):
        outbox.put(next(inbox))

    return n
def tensorflow_scheduler(global_future,
                         model_name,
                         client=None,
                         tf_option=None,
                         tf_port=None,
                         **tf_cluster_spec):
    """Start the TensorFlow actors for ``model_name`` and wait for them.

    This is a generator-style coroutine (it ``yield``s futures).
    NOTE(review): it presumably runs under a coroutine decorator that is not
    visible in this block -- confirm.

    Args:
        global_future: future that receives the ``chief_actors`` mapping once
            the run has been cleaned up.
        model_name: name of the model; also used as the pub/sub topic.
        client: dask client used for every scheduler/worker interaction.
        tf_option: TensorFlow options, either already serialized
            (``str``/``bytes``) or an object with ``SerializeToString()``.
        tf_port: accepted for interface compatibility; not used here.
        **tf_cluster_spec: forwarded to ``tensorflow_gen_config``.
    """
    scheduler_info = yield client.scheduler.identity()
    cuda_free_map = yield client.run(cuda_free_indexes)
    tf_configs = tensorflow_gen_config(free_node_name_map=cuda_free_map,
                                       **tf_cluster_spec)
    logger.info('Model Schedule %s: \n  tf_configs:%s\n\n', model_name,
                tf_configs)

    # Normalize tf_option to its serialized form (or leave a falsy value).
    if not isinstance(tf_option, (str, bytes)):
        tf_option = tf_option.SerializeToString() if tf_option else tf_option

    # Bucket the task configurations by role.
    chief_configs, ps_configs, other_configs = [], [], []
    for tf_config in tf_configs:
        task_type = tf_config['task']['type']
        task_index = tf_config['task']['index']
        if task_type in ('chief', 'master'):
            chief_configs.append(tf_config)
        elif task_type in ('ps', ):
            ps_configs.append(tf_config)
        else:
            other_configs.append(tf_config)

    for configs in (chief_configs, ps_configs, other_configs):
        configs.sort(key=lambda cfg: cfg['task']['index'])

    client.loop.set_default_executor(
        ThreadPoolExecutor(max_workers=len(tf_configs)))

    result_future = Future()
    result_future.tf_configs = tf_configs
    result_future.tf_option = tf_option
    result_future.cuda_map = cuda_free_map

    chief_future = Future()
    client.loop.add_callback(startup_actors, scheduler_info, client,
                             model_name, tf_option, tf_configs, chief_future)
    chief_actors = yield chief_future

    sorted_task_keys = sorted(chief_actors.keys(),
                              key=lambda x: dask_sork_key(x))

    sub = Sub(model_name, client=client)
    pubs = {k: Pub(model_name, client=client) for k in sorted_task_keys}
    # data flush sync between this client and scheduler
    scheduler_info = yield client.scheduler.identity()

    # NOTE(review): this callback is defined but never registered or called
    # inside this function.
    def chief_finish(task_key, actor, fu):
        value = fu.result()
        logger.info('Tensorflow Finished[%s/%s], key:%s, val:%s',
                    len(chief_actors), len(tf_configs), task_key, value)
        chief_actors[task_key] = actor
        if len(chief_actors) == len(tf_configs):
            logger.info('Tensorflow Cluster All Finished: %s',
                        chief_actors.keys())

    # Chief First.
    msgs = {}
    chief_key_actor = sorted_task_keys[0]
    while (len(msgs) + 1) < len(chief_actors):
        msg = yield sub._get()
        logger.info('Sub Rcv %s:%s', type(msg), msg)
        msgs.update(msg)

    # The two ``pdb.set_trace()`` breakpoints that used to sit around the
    # clean-up below were debugging leftovers; they halted the coroutine and
    # have been removed.
    assert chief_key_actor in msgs, \
        'Tensorflow Chief Task Required: %s' % chief_key_actor

    # NOTE(review): time.sleep blocks the calling thread for a second; if
    # this runs on the event loop thread a non-blocking sleep may be
    # preferable -- confirm.
    time.sleep(1)
    future = yield model_cleanup(client, model_name)
    logger.info("Tensorflow Task clean, %s", chief_actors)
    global_future.set_result(chief_actors)
def sample(density, client=None, n_chain=4, n_iter=None, n_warmup=None,
           trace=None, random_state=None, x_0=None, verbose=True,
           return_trace=False, density_options=None, sampler_options=None,
           sampler='NUTS'):
    """Run NUTS chains on a dask client and gather their samples.

    Parameters
    ----------
    density : Density or DensityLite
        Target whose ``logp_and_grad`` is sampled.
    client : optional
        Passed to ``check_client``; a client created there is closed again
        before returning.
    n_chain : int, optional
        Number of chains. Ignored when ``trace`` supplies existing traces.
    n_iter, n_warmup : int or None, optional
        Iterations per chain. ``None`` selects 3000/1000 for fresh chains and
        1000/0 when continuing from ``trace``.
    trace : Trace or iterable of Trace, optional
        Existing trace(s) to continue from.
    random_state : optional
        One state to be split into ``n_chain`` states, or an iterable with at
        least ``n_chain`` states. Only used for fresh chains.
    x_0 : array_like, optional
        Starting points, at least ``n_chain`` rows. Only used for fresh
        chains.
    verbose : bool, optional
        Forwarded to ``NUTS.run``.
    return_trace : bool, optional
        If true, return ``(samples, traces)`` instead of only the samples.
    density_options : dict or None, optional
        Extra keyword arguments for ``density.logp_and_grad``. ``None`` means
        no extra options.
    sampler_options : dict or None, optional
        Extra keyword arguments for the ``NUTS`` constructor. ``None`` means
        no extra options.
    sampler : str, optional
        Only ``'NUTS'`` is implemented.

    Raises
    ------
    ValueError
        On invalid arguments.
    NotImplementedError
        For ``sampler`` in ``('HMC', 'EnsembleSampler')``.
    """
    # DEVELOPMENT NOTES
    # if use_surrogate is not specified in density_options
    # x_0 is interpreted as in original space and will be transformed
    # otherwise, x_0 is understood as in the specified space

    def positive_int(value, message):
        # Explicit checks instead of ``assert`` inside a bare ``except``:
        # they still run under ``python -O`` and do not swallow
        # KeyboardInterrupt/SystemExit.
        try:
            value = int(value)
        except (TypeError, ValueError, OverflowError):
            raise ValueError(message)
        if value <= 0:
            raise ValueError(message)
        return value

    if not isinstance(density, (Density, DensityLite)):
        raise ValueError('density should be a Density or DensityLite.')

    if isinstance(trace, Trace):
        trace = [trace]
    if (hasattr(trace, '__iter__') and len(trace) > 0 and
            all(isinstance(t, Trace) for t in trace)):
        n_chain = len(trace)
    else:
        n_chain = positive_int(n_chain, 'invalid value for n_chain')
        trace = [None for i in range(n_chain)]
    fresh_start = trace[0] is None

    if n_iter is None:
        n_iter = 3000 if fresh_start else 1000
    else:
        n_iter = positive_int(n_iter, 'invalid value for n_iter.')

    if n_warmup is None:
        n_warmup = 1000 if fresh_start else 0
    else:
        n_warmup = positive_int(n_warmup, 'invalid value for n_warmup.')

    # ``None`` replaces the former shared ``{}`` defaults.
    if density_options is None:
        density_options = {}
    if sampler_options is None:
        sampler_options = {}
    try:
        # Work on a private copy so the caller's dict is never modified.
        density_options = dict(density_options)
    except (TypeError, ValueError):
        raise ValueError('density_options should be a dict.')
    density_options.setdefault('use_surrogate', True)
    _transform_x = 'original_space' not in density_options
    if _transform_x:
        density_options['original_space'] = False

    if fresh_start:
        if hasattr(random_state, '__iter__'):
            random_state = [bfrandom.check_state(rs) for rs in random_state]
            if len(random_state) < n_chain:
                raise ValueError('you did not give me enough random_state(s).')
        else:
            random_state = bfrandom.check_state(random_state)
            random_state = bfrandom.split_state(random_state, n_chain)

        if x_0 is None:
            dim = density.input_size
            x_0 = bfrandom.multivariate_normal(np.zeros(dim), np.eye(dim),
                                               n_chain)
        else:
            x_0 = np.atleast_2d(x_0)
            if x_0.shape[0] < n_chain:
                raise ValueError('you did not give me enough x_0(s).')
            x_0 = x_0[:n_chain, :]
            # NOTE(review): only user-supplied x_0 is mapped with
            # from_original; randomly drawn x_0 is used as is -- confirm.
            if _transform_x:
                x_0 = density.from_original(x_0)
    else:
        random_state = [None for i in range(n_chain)]
        x_0 = [None for i in range(n_chain)]

    # Resolved outside the try block: if this raises there is nothing to
    # clean up, and the ``finally`` clause must not touch an unbound
    # ``_new_client``.
    client, _new_client = check_client(client)
    try:
        # dask_key = bfrandom.string()
        dask_key = 'BayesFast-' + client.id
        sub = Sub(dask_key)
        finished = 0

        if sampler == 'NUTS':
            def nuts_worker(i):
                with threadpool_limits(1):
                    def logp_and_grad(x):
                        return density.logp_and_grad(x, **density_options)
                    nuts = NUTS(logp_and_grad=logp_and_grad, trace=trace[i],
                                dask_key=dask_key, chain_id=i,
                                random_state=random_state[i], x_0=x_0[i],
                                **sampler_options)
                    t = nuts.run(n_iter, n_warmup, verbose)
                return t  # if return_trace else t.get()

            foo = client.map(nuts_worker, range(n_chain))

            # Relay progress messages from the chains until all of them have
            # finished or one reports an error.
            for msg in sub:
                if not hasattr(msg, '__iter__'):
                    warnings.warn('unexpected message: {}.'.format(msg),
                                  RuntimeWarning)
                elif msg[0] == 'Error':
                    break
                elif isclass(msg[0]) and issubclass(msg[0], Warning):
                    warnings.warn(msg[1], msg[0])
                elif msg[0] == 'SamplingProceeding':
                    print(msg[1])
                elif msg[0] == 'SamplingFinished':
                    print(msg[1])
                    finished += 1
                else:
                    warnings.warn('unexpected message: {}.'.format(msg),
                                  RuntimeWarning)
                if finished == n_chain:
                    break

            tt = client.gather(foo)
            if _transform_x:
                xx = np.array([density.to_original(t.get()) for t in tt])
            else:
                xx = np.array([t.get() for t in tt])
            return (xx, tt) if return_trace else xx

        elif sampler == 'HMC':
            raise NotImplementedError
        elif sampler == 'EnsembleSampler':
            raise NotImplementedError
        else:
            raise ValueError(
                'Sorry I do not know how to do {}.'.format(sampler))
    finally:
        if _new_client:
            client.cluster.close()
            client.close()
def f(_):
    """Subscribe to topic "a" and return the first five messages as a list.

    The single positional argument is ignored.
    """
    first_five = toolz.take(5, Sub("a"))
    return list(first_five)
def sample(density, sample_trace=None, sampler='NUTS', n_run=None,
           parallel_backend=None, verbose=True):
    """Run the chains described by ``sample_trace`` on a parallel backend.

    Parameters
    ----------
    density : Density or DensityLite
        Target whose ``logp_and_grad`` is sampled.
    sample_trace : NTrace, HTrace, TNTrace, THTrace, TraceTuple, dict or None
        Trace configuration. A trace instance fixes the sampler; a ``dict``
        (or ``None``) is expanded into the trace class matching ``sampler``;
        a ``TraceTuple`` continues from its stored traces.
    sampler : str, optional
        One of ``'NUTS'``, ``'HMC'``, ``'TNUTS'``, ``'THMC'``. Only consulted
        when ``sample_trace`` is a ``dict`` or ``None``.
    n_run : optional
        Forwarded to the sampler's ``run``.
    parallel_backend : optional
        ``None`` uses ``get_backend()``; anything else is wrapped in
        ``ParallelBackend``.
    verbose : bool, optional
        Forwarded to the sampler's ``run``.

    Returns
    -------
    TraceTuple
        The traces of all chains.

    Raises
    ------
    ValueError, RuntimeError
        On invalid arguments or an unsupported backend.
    NotImplementedError
        For ensemble sampling.
    """
    supported = ('NUTS', 'HMC', 'TNUTS', 'THMC')

    if not isinstance(density, (Density, DensityLite)):
        raise ValueError('density should be a Density or DensityLite.')

    # Resolve (sample_trace, sampler) into a consistent pair.
    if isinstance(sample_trace, NTrace):
        sampler = 'NUTS'
    elif isinstance(sample_trace, HTrace):
        sampler = 'HMC'
    elif isinstance(sample_trace, TNTrace):
        sampler = 'TNUTS'
    elif isinstance(sample_trace, THTrace):
        sampler = 'THMC'
    elif isinstance(sample_trace, ETrace):
        raise NotImplementedError
    elif sample_trace is None or isinstance(sample_trace, dict):
        sample_trace = {} if (sample_trace is None) else sample_trace
        if sampler == 'NUTS':
            sample_trace = NTrace(**sample_trace)
        elif sampler == 'HMC':
            sample_trace = HTrace(**sample_trace)
        elif sampler == 'TNUTS':
            sample_trace = TNTrace(**sample_trace)
        elif sampler == 'THMC':
            sample_trace = THTrace(**sample_trace)
        elif sampler == 'Ensemble':
            raise NotImplementedError
        else:
            raise ValueError('unexpected value for sampler.')
    elif isinstance(sample_trace, TraceTuple):
        sampler = sample_trace.sampler
        if any(sampler == _ for _ in supported):
            pass
        elif sampler == 'Ensemble':
            raise NotImplementedError
        else:
            raise ValueError('unexpected value for sample_trace.sampler.')
    else:
        raise ValueError('unexpected value for sample_trace.')

    # Fill in the random generator and the starting points of a fresh trace.
    if isinstance(sample_trace, SampleTrace):
        if sample_trace.random_generator is None:
            sample_trace.random_generator = get_generator()
            # NOTE(review): draws one number from the shared generator right
            # after handing it out -- presumably to advance its state;
            # confirm.
            get_generator().normal()
        if sample_trace.x_0 is None:
            dim = density.input_size
            if dim is None:
                raise RuntimeError('Neither SampleTrace.x_0 nor Density'
                                   '/DensityLite.input_size is defined.')
            sample_trace._x_0 = multivariate_normal(np.zeros(dim), np.eye(dim),
                                                    sample_trace.n_chain)
            sample_trace._x_0_transformed = True
        elif not sample_trace.x_0_transformed:
            sample_trace._x_0 = density.from_original(sample_trace._x_0)
            sample_trace._x_0_transformed = True

    if parallel_backend is None:
        parallel_backend = get_backend()
    else:
        parallel_backend = ParallelBackend(parallel_backend)

    # Only the dask backend uses pub/sub, and only the multiprocess backend
    # gets a process lock; everything else runs with the defaults below.
    use_dask = False
    dask_key = None
    process_lock = None
    kind = parallel_backend.kind
    if kind == 'multiprocess':
        process_lock = Manager().Lock()
    elif kind == 'dask':
        if not HAS_DASK:
            raise RuntimeError(
                'you want me to use dask but have not installed '
                'it.')
        use_dask = True
        dask_key = 'BayesFast-' + parallel_backend.backend.id
        sub = Sub(dask_key)
        finished = 0
    elif kind in ('ray', 'sharedmem', 'loky'):
        pass
    # elif kind == 'serial':
    #     pass
    else:
        raise RuntimeError('unexpected value for parallel_backend.kind.')

    def nested_helper(sample_trace, i):
        """Without this, there will be an UnboundLocalError."""
        if isinstance(sample_trace, SampleTrace):
            sample_trace._init_chain(i)
        elif isinstance(sample_trace, TraceTuple):
            sample_trace = sample_trace.sample_traces[i]
        else:
            raise RuntimeError('unexpected type for sample_trace.')
        return sample_trace

    def _sampler_worker(i, sampler_class):
        try:
            with threadpool_limits(1):
                _sample_trace = nested_helper(sample_trace, i)

                def logp_and_grad(x):
                    return density.logp_and_grad(x, original_space=False)

                _sampler = sampler_class(logp_and_grad=logp_and_grad,
                                         sample_trace=_sample_trace,
                                         dask_key=dask_key,
                                         process_lock=process_lock)
                t = _sampler.run(n_run, verbose)
                t._samples_original = density.to_original(t.samples)
                t._logp_original = density.to_original_density(
                    t.logp, x_trans=t.samples)
            return t
        except Exception:
            # Tell the driver loop to stop listening before re-raising.
            if use_dask:
                pub = Pub(dask_key)
                pub.put(['Error', i])
            raise

    with parallel_backend:
        if sampler in supported:
            # Explicit lookup instead of ``eval(sampler)``.
            sampler_class = {'NUTS': NUTS, 'HMC': HMC, 'TNUTS': TNUTS,
                             'THMC': THMC}[sampler]
            n_chain = sample_trace.n_chain
            chain_ids = range(n_chain)
            classes = [sampler_class] * n_chain

            if use_dask:
                foo = parallel_backend.map_async(_sampler_worker, chain_ids,
                                                 classes)
                # Relay progress messages until every chain has finished or
                # one of them reports an error.
                for msg in sub:
                    if not hasattr(msg, '__iter__'):
                        warnings.warn('unexpected message: {}.'.format(msg),
                                      RuntimeWarning)
                    elif msg[0] == 'Error':
                        break
                    elif isclass(msg[0]) and issubclass(msg[0], Warning):
                        warnings.warn(msg[1], msg[0])
                    elif msg[0] == 'SamplingProceeding':
                        print(msg[1])
                    elif msg[0] == 'SamplingFinished':
                        print(msg[1])
                        finished += 1
                    else:
                        warnings.warn('unexpected message: {}.'.format(msg),
                                      RuntimeWarning)
                    if finished == sample_trace.n_chain:
                        break
                tt = parallel_backend.gather(foo)
            else:
                tt = parallel_backend.map(_sampler_worker, chain_ids, classes)
            return TraceTuple(tt)

        elif sampler == 'Ensemble':
            raise NotImplementedError
        else:
            raise RuntimeError('unexpected value for sampler.')
def __init__(self,
             key,
             *args,
             tf_config=None,
             tf_option=None,
             scheduler_info=None,
             **kwargs):
    """Prepare a TensorFlow task on the current dask worker.

    The constructor secedes from the worker's thread pool, creates the
    per-model data directories, changes into the data directory, redirects
    the process-wide stdout/stderr into a per-task log file, and finally
    schedules ``self.flight`` on the worker's event loop.

    Args:
        key: task key. The text before the first ``:`` is the model name
            (a trailing ``.zip`` is dropped for ``self.model_name``); the
            text after it names the log file.
        *args: kept on ``self.args``.
        tf_config: TensorFlow cluster configuration, as a JSON string or
            as an already decoded object.
        tf_option: TensorFlow options, as a JSON string or any other
            object (kept as is).
        scheduler_info: kept on ``self.scheduler_info``.
        **kwargs: kept on ``self.kwargs``.
    """
    import shutil

    # here we make this thread an OWNER of this task by seceding from its
    # ThreadPoolExecutor.
    # NOTE: `thread_pool_secede` ONLY works in a NON-coroutine
    # actor executor, ref: worker.actor_execute()
    self.dask_worker = get_worker()
    self.thrid = threading.get_ident()
    thread_pool_secede(adjust=True)
    self.dask_worker.loop.add_callback(self.dask_worker.transition,
                                       thread_state.key, "long-running")

    self.key = key
    self.name = self.dask_worker.name
    self.hostname = socket.gethostname()
    self.address = self.dask_worker.address
    self.scheduler_info = scheduler_info

    model_name = self.key.partition(':')[0]
    if model_name.endswith('.zip'):
        self.model_name = model_name[:-4]
    else:
        self.model_name = model_name

    if isinstance(tf_option, str):
        self.tf_option = json.loads(tf_option)
    else:
        self.tf_option = tf_option
    if isinstance(tf_config, str):
        self.tf_config = json.loads(tf_config)
    else:
        self.tf_config = tf_config

    self.dask_cwd = os.path.abspath(os.getcwd())
    self.tf_model_pool_dir = os.path.abspath(DASK_MODEL_POOL_DIR)
    self.tf_data_pool_dir = os.path.abspath(DASK_DATA_POOL_DIR)
    self.tf_data_dir = os.path.join(self.tf_data_pool_dir, self.model_name)
    self.tf_config_dir = os.path.join(self.tf_data_dir, 'config')
    self.tf_save_dir = os.path.join(self.tf_data_dir, 'ckpt')
    self.tf_log_dir = os.path.join(self.tf_data_dir, 'log')

    # Start from an empty checkpoint directory and make sure every
    # directory exists.  Done with the standard library rather than a
    # shell command line so that unusual characters in the paths cannot
    # be interpreted by a shell.
    shutil.rmtree(self.tf_save_dir, ignore_errors=True)
    for directory in (self.tf_data_dir, self.tf_config_dir,
                      self.tf_save_dir, self.tf_log_dir):
        os.makedirs(directory, exist_ok=True)
    os.chdir(self.tf_data_dir)

    # Remember the process-wide streams so that they can be restored later,
    # then send stdout/stderr to the per-task log file.
    self.sys_stdio = (sys.__stdin__, sys.__stdout__, sys.__stderr__,
                      sys.stdin, sys.stdout, sys.stderr)
    log_name = '%s.log' % self.key.partition(':')[-1]
    self.stdout = open(os.path.join(self.tf_log_dir, log_name),
                       'a+',
                       encoding=sys.stdout.encoding)
    sys.__stdout__ = sys.__stderr__ = sys.stdout = sys.stderr = self.stdout
    self.stdin = sys.stdin
    self.sys_path, self.sys_argv = sys.path[:], sys.argv[:]

    logger.info(
        'Accepted Tensorflow Key:%s, Job:%s, Options:%s, Scheduler:%s',
        key, tf_config, tf_option, scheduler_info)
    self.devices = dict(tensorflow_devices())
    self.future_chunk_size = DASK_READ_CHUNK_SIZE
    self.args = args
    self.kwargs = kwargs

    # The subscription listens on the task key; results are published on
    # the model-name part of the key (including any ``.zip`` suffix).
    self.sub = Sub(self.key, worker=self.dask_worker)
    self.result = Pub(model_name, worker=self.dask_worker)
    self.exe = self.preflight()
    self.dask_worker.loop.add_callback(self.flight, self.exe)
class TFActor(object):
    """Actor that runs one TensorFlow task as a subprocess of a dask worker.

    Construction prepares directories and redirects the process-wide
    stdout/stderr to a per-task log file; ``flight`` starts the subprocess;
    ``landed`` is registered as its exit callback; ``recovery`` restores the
    process state that the constructor changed.
    """

    def __del__(self):
        self.recovery()

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.key)

    def __init__(self,
                 key,
                 *args,
                 tf_config=None,
                 tf_option=None,
                 scheduler_info=None,
                 **kwargs):
        """Prepare a TensorFlow task on the current dask worker.

        Args:
            key: task key. The text before the first ``:`` is the model
                name (a trailing ``.zip`` is dropped for
                ``self.model_name``); the text after it names the log file.
            *args: kept on ``self.args``.
            tf_config: TensorFlow cluster configuration, as a JSON string
                or as an already decoded object.
            tf_option: TensorFlow options, as a JSON string or any other
                object (kept as is).
            scheduler_info: kept on ``self.scheduler_info``.
            **kwargs: kept on ``self.kwargs``.
        """
        import shutil

        # here we make this thread an OWNER of this task by seceding from
        # its ThreadPoolExecutor.
        # NOTE: `thread_pool_secede` ONLY works in a NON-coroutine
        # actor executor, ref: worker.actor_execute()
        self.dask_worker = get_worker()
        self.thrid = threading.get_ident()
        thread_pool_secede(adjust=True)
        self.dask_worker.loop.add_callback(self.dask_worker.transition,
                                           thread_state.key, "long-running")

        self.key = key
        self.name = self.dask_worker.name
        self.hostname = socket.gethostname()
        self.address = self.dask_worker.address
        self.scheduler_info = scheduler_info

        model_name = self.key.partition(':')[0]
        if model_name.endswith('.zip'):
            self.model_name = model_name[:-4]
        else:
            self.model_name = model_name

        if isinstance(tf_option, str):
            self.tf_option = json.loads(tf_option)
        else:
            self.tf_option = tf_option
        if isinstance(tf_config, str):
            self.tf_config = json.loads(tf_config)
        else:
            self.tf_config = tf_config

        self.dask_cwd = os.path.abspath(os.getcwd())
        self.tf_model_pool_dir = os.path.abspath(DASK_MODEL_POOL_DIR)
        self.tf_data_pool_dir = os.path.abspath(DASK_DATA_POOL_DIR)
        self.tf_data_dir = os.path.join(self.tf_data_pool_dir,
                                        self.model_name)
        self.tf_config_dir = os.path.join(self.tf_data_dir, 'config')
        self.tf_save_dir = os.path.join(self.tf_data_dir, 'ckpt')
        self.tf_log_dir = os.path.join(self.tf_data_dir, 'log')

        # Start from an empty checkpoint directory and make sure every
        # directory exists.  Done with the standard library rather than a
        # shell command line so that unusual characters in the paths cannot
        # be interpreted by a shell.
        shutil.rmtree(self.tf_save_dir, ignore_errors=True)
        for directory in (self.tf_data_dir, self.tf_config_dir,
                          self.tf_save_dir, self.tf_log_dir):
            os.makedirs(directory, exist_ok=True)
        os.chdir(self.tf_data_dir)

        # Remember the process-wide streams so that ``recovery`` can restore
        # them, then send stdout/stderr to the per-task log file.
        self.sys_stdio = (sys.__stdin__, sys.__stdout__, sys.__stderr__,
                          sys.stdin, sys.stdout, sys.stderr)
        log_name = '%s.log' % self.key.partition(':')[-1]
        self.stdout = open(os.path.join(self.tf_log_dir, log_name),
                           'a+',
                           encoding=sys.stdout.encoding)
        sys.__stdout__ = sys.__stderr__ = sys.stdout = sys.stderr = self.stdout
        self.stdin = sys.stdin
        self.sys_path, self.sys_argv = sys.path[:], sys.argv[:]

        logger.info(
            'Accepted Tensorflow Key:%s, Job:%s, Options:%s, Scheduler:%s',
            key, tf_config, tf_option, scheduler_info)
        self.devices = dict(tensorflow_devices())
        self.future_chunk_size = DASK_READ_CHUNK_SIZE
        self.args = args
        self.kwargs = kwargs

        # The subscription listens on the task key; results are published
        # on the model-name part of the key (including any ``.zip`` suffix).
        self.sub = Sub(self.key, worker=self.dask_worker)
        self.result = Pub(model_name, worker=self.dask_worker)
        self.exe = self.preflight()
        self.dask_worker.loop.add_callback(self.flight, self.exe)

    def device_info(self, xla=None, gpu=True):
        """Return the known devices selected by name, run through ``gpu_filter``.

        Args:
            xla: ``None`` selects devices whose name contains ``'GPU'``;
                ``True`` those whose name contains ``'XLA'``; any other
                value those whose name does not contain ``'XLA'``.
            gpu: forwarded to ``gpu_filter`` as ``gpu_flag``.
        """
        devices = self.devices.values()
        if xla is None:
            selected = [v for v in devices if 'GPU' in v['name']]
        elif xla is True:
            selected = [v for v in devices if 'XLA' in v['name']]
        else:
            selected = [v for v in devices if 'XLA' not in v['name']]
        return gpu_filter(selected, gpu_flag=gpu)

    def tensorflow_env(self,
                       tf_option,
                       tf_config,
                       dask_context,
                       cuda_indexes=None):
        """Build the environment mapping for the model subprocess.

        The entry point is ``<model pool>/<model>/__main__.py`` when that
        file exists, otherwise ``<model pool>/<model>.zip``.

        Args:
            tf_option: accepted for interface compatibility; not used here.
            tf_config: stored JSON-encoded in ``TF_CONFIG``.
            dask_context: stored JSON-encoded in ``TF_CONTEXT``.
            cuda_indexes: if truthy, stored in ``CUDA_VISIBLE_DEVICES``.

        Returns:
            dict of environment variables.

        Raises:
            Exception: with ``USAGE_INFO`` when no entry point is found.
        """
        model_entrypoint = os.path.join(self.tf_model_pool_dir,
                                        self.model_name)
        zip_ep = model_entrypoint + '.zip'
        pkg_ep = os.path.join(model_entrypoint, '__main__.py')
        if os.path.isfile(pkg_ep):
            model_entrypoint = pkg_ep
        elif os.path.isfile(zip_ep):
            model_entrypoint = zip_ep
        else:
            raise Exception(USAGE_INFO)

        # Pass through the subset of the current environment listed here.
        inherited = ('LANG', 'PATH', 'CUDA_HOME', 'LD_LIBRARY_PATH', 'USER',
                     'HOME', 'HOSTNAME', 'SHELL', 'TERM', 'SHLVL', 'MAIL',
                     'SSH_CONNECTION', 'SSH_TTY', 'SSH_CLIENT')
        env_dict = {
            name: os.environ[name]
            for name in inherited if name in os.environ
        }

        env_dict.update(
            XLA_FLAGS='--xla_hlo_profile',
            TF_DASK_PID=str(os.getpid()),
            RF_DASK_PGRP=str(os.getpgrp()),
            TF_XLA_FLAGS=("--tf_xla_cpu_global_jit " +
                          os.environ.get("TF_XLA_FLAGS", "")),
            TF_MODEL=self.model_name,
            TF_CONTEXT=json.dumps(dask_context),
            TF_CONFIG=json.dumps(tf_config),
            TF_MODEL_POOL_DIR=self.tf_model_pool_dir,
            TF_DATA_POOL_DIR=self.tf_data_pool_dir,
            TF_MODEL_ENTRYPOINT=model_entrypoint,
            TF_CONFIG_DIR=self.tf_config_dir,
            TF_DATA_DIR=self.tf_data_dir,
            TF_LOG_DIR=self.tf_log_dir,
            TF_SAVE_DIR=self.tf_save_dir,
            PYTHONPATH=':'.join(
                [self.tf_model_pool_dir, self.tf_data_dir, self.dask_cwd]),
            PYTHONHASHSEED=str(int(DASK_PYTHONHASHSEED)),
            PYTHONIOENCODING=sys.getdefaultencoding(),
            PYTHONUNBUFFERED='True',
        )

        if cuda_indexes:
            # we explicitly assign GPU indexes to use; let tensorflow be
            # aware of ONLY these indexes
            env_dict['CUDA_VISIBLE_DEVICES'] = cuda_indexes

        return env_dict

    def log(self, msg, *args, flush=True):
        """Write ``msg`` (%-formatted with ``args`` if any) to the task log."""
        text = (msg % args) if args else msg
        self.stdout.write(text)
        if flush:
            self.stdout.flush()

    def run_model(self, stdin, stdout, *args, **kwargs):
        """Redirect the standard streams and import the model entry point.

        The entry point is read from the ``TF_MODEL_ENTRYPOINT`` environment
        variable and must end in ``.zip`` or ``.py``.

        Raises:
            ValueError: if the variable is unset or has another suffix.
        """
        import sys

        sys.stdin = stdin
        self.stdout = sys.stdout = sys.stderr = sys.__stdout__ = sys.__stderr__ = stdout

        self.log('HERE IN ASYNC SUBPROCESS: %s' % os.getpid())

        model_entry = os.getenv('TF_MODEL_ENTRYPOINT')
        if model_entry is None:
            raise ValueError('TF_MODEL_ENTRYPOINT is not set')
        if model_entry.endswith('.zip'):
            model_root, modname = model_entry, '__main__'
        elif model_entry.endswith('.py'):
            model_root = os.path.dirname(model_entry)
            modname = os.path.basename(model_entry).rsplit('.', 1)[0]
        else:
            # Fail with a clear message instead of a NameError further down.
            raise ValueError(
                'unsupported TF_MODEL_ENTRYPOINT: %r' % (model_entry, ))

        self.log('HERE IN ASYNC MODEL START, %s, %s' % (modname, model_root))
        sys.path.insert(0, model_root)
        __import__(modname)

    def preflight(self):
        """Assemble the subprocess launcher for this task without starting it.

        Returns:
            A ``functools.partial`` around ``process.Subprocess``; calling it
            starts the model process.
        """
        # this NODE is selected for this task
        node_name, node_port, cuda_indexes, dask_url = self.tf_config.pop(
            'dask').split(':', 3)
        task = self.tf_config['task']
        job_name, task_index = task['type'], task['index']
        tensorflow_addr = self.tf_config['cluster'][job_name][task_index]

        using_xla_gpu_devices = self.device_info(xla=True, gpu=True)
        using_xla_gpu_device_names = sorted(
            x['name'] for x in using_xla_gpu_devices)

        if isinstance(self.tf_option, (str, bytes)):
            import tensorflow as tf
            tf_option = tf.compat.v1.ConfigProto.FromString(self.tf_option)
        elif self.tf_option is not None:
            tf_option = self.tf_option
        else:
            tf_option = tensorflow_options()

        dask_context = {
            'model_task':
            '%s, %s' % (self.key, ','.join(using_xla_gpu_device_names)),
            'model_addr': tensorflow_addr,
            'worker_addr': self.address,
            'schduler_addr': self.scheduler_info,
            'workspace': DASK_WORKSPACE,
            'local_dir': self.dask_cwd,
            'pid': os.getpid(),
            'thread_id': self.thrid,
            'code': 0,
        }

        env_dict = self.tensorflow_env(tf_option,
                                       self.tf_config,
                                       dask_context,
                                       cuda_indexes=cuda_indexes)
        cmd = [
            sys.executable, r'-u', env_dict['TF_MODEL_ENTRYPOINT'], self.key
        ]
        fmt = 'Model Start, key:%s,\n  cmd:%s\n  dask_context:%s\n  sys.path:%s\n  tf_option:%s\n  tf_config:%s\n\n'
        self.log(fmt % (self.key, cmd, dask_context, self.sys_path, tf_option,
                        self.tf_config))

        # Report any environment entry that is not a str/str pair.
        for k, v in env_dict.items():
            if not isinstance(k, str) or not isinstance(v, str):
                self.log('Error env k:%s, v:%s\n' % (k, v))

        exe_package = partial(process.Subprocess,
                              cmd,
                              executable=DASK_PYTHON_INTERPRETER,
                              cwd=env_dict['TF_DATA_DIR'],
                              env=env_dict,
                              preexec_fn=None,
                              stdin=self.stdin,
                              stdout=self.stdout,
                              stderr=self.stdout,
                              encoding=sys.getdefaultencoding(),
                              pass_fds=(self.stdin.fileno(),
                                        self.stdout.fileno()),
                              universal_newlines=False,
                              bufsize=0,
                              restore_signals=False,
                              start_new_session=False)
        return exe_package

    def flight(self, exe_package):
        """Start the subprocess and register ``landed`` as its exit callback."""
        # flighting in main thread, since `SIGCHLD` MUST be received in it;
        # and then correctly call the exit callback.
        self.exe = exe_package()
        self.exe.set_exit_callback(self.landed)
        msg = '\n ***** Tensorflow Task Inited, key:%s, sub:%s, pid:%s ***** ' % (
            self.key, self.exe.pid, os.getpid())
        self.log(msg)

    def landed(self, retval=0):
        """Publish the exit code, wait for a reply, then restore state.

        NOTE(review): this is a generator function (it uses ``yield``);
        called as a plain callback it only creates a generator object and
        none of the body runs. It presumably relies on a coroutine decorator
        that is not visible in this block -- confirm.
        """
        self.log("worker pub msg: %s", {self.key: retval})
        self.result.put({self.key: retval})
        ident = yield self.dask_worker.scheduler.identity()
        msg = yield self.sub._get(timeout=10)
        self.log('Tensorflow Push Message Received, sub:%s, msg:%s, ident:%s' %
                 (self.key, msg, ident))
        msg = '\n ***** Tensorflow Task Finished, key:%s, ret:%s, tid:%s, pid:%s, ***** \n\n' % (
            self.key, retval, threading.get_ident(), os.getpid())
        self.log(msg)
        self.recovery()

    def recovery(self):
        """Undo the process-wide changes made by ``__init__``.

        Safe to call more than once, and on an instance whose ``__init__``
        never got as far as redirecting the standard streams (``__del__``
        calls this unconditionally).
        """
        if getattr(self, 'sys_stdio', None) is None:
            return

        self.log('State Recovery:%s', self.key)
        os.chdir(self.dask_cwd)
        (sys.__stdin__, sys.__stdout__, sys.__stderr__, sys.stdin, sys.stdout,
         sys.stderr) = self.sys_stdio
        sys.path, sys.argv = self.sys_path, self.sys_argv

        # Close streams owned by this actor; only flush the ones that are
        # the restored interpreter streams.
        if self.stdin:
            if self.stdin != sys.__stdin__:
                self.stdin.close()
            else:
                self.stdin.flush()

        if self.stdout:
            if self.stdout != sys.__stdout__:
                self.stdout.close()
            else:
                self.stdout.flush()

        self.stdin = self.stdout = self.sys_stdio = self.sys_path = self.sys_argv = self.dask_worker = None

        # These are assigned last in ``__init__`` and may be missing if
        # construction failed part-way, so drop them without raising.
        for attr in ('result', 'exe', 'sub'):
            self.__dict__.pop(attr, None)