def deserialize_numpy_ndarray(header, frames):
    with log_errors():
        if header.get("pickle"):
            return pickle.loads(frames[0], buffers=frames[1:])

        (frame,) = frames
        (writeable,) = header["writeable"]

        is_custom, dt = header["dtype"]
        if is_custom:
            dt = pickle.loads(dt)
        else:
            dt = np.dtype(dt)

        if header.get("broadcast_to"):
            shape = header["broadcast_to"]
        else:
            shape = header["shape"]

        x = np.ndarray(shape, dtype=dt, buffer=frame, strides=header["strides"])
        if not writeable:
            x.flags.writeable = False
        else:
            x = np.require(x, requirements=["W"])

        return x
def test_pickle_numpy(): np = pytest.importorskip("numpy") x = np.ones(5) assert (loads(dumps(x)) == x).all() x = np.ones(5000) assert (loads(dumps(x)) == x).all()
def test_pickle_numpy(): np = pytest.importorskip("numpy") x = np.ones(5) assert (loads(dumps(x)) == x).all() assert (deserialize(*serialize(x, serializers=("pickle", ))) == x).all() x = np.ones(5000) assert (loads(dumps(x)) == x).all() assert (deserialize(*serialize(x, serializers=("pickle", ))) == x).all() if HIGHEST_PROTOCOL >= 5: x = np.ones(5000) l = [] d = dumps(x, buffer_callback=l.append) assert len(l) == 1 assert isinstance(l[0], PickleBuffer) assert memoryview(l[0]) == memoryview(x) assert (loads(d, buffers=l) == x).all() h, f = serialize(x, serializers=("pickle", )) assert len(f) == 2 assert isinstance(f[0], bytes) assert isinstance(f[1], memoryview) assert (deserialize(h, f) == x).all()
def test_worker(c, a, b):
    aa = rpc(ip=a.ip, port=a.port)
    bb = rpc(ip=b.ip, port=b.port)

    result = yield aa.identity()
    assert not a.active

    response = yield aa.compute(key='x',
                                function=dumps(add),
                                args=dumps([1, 2]),
                                who_has={},
                                close=True)
    assert not a.active
    assert response['status'] == 'OK'
    assert a.data['x'] == 3
    assert isinstance(response['compute_start'], float)
    assert isinstance(response['compute_stop'], float)
    assert isinstance(response['thread'], Integral)

    response = yield bb.compute(key='y',
                                function=dumps(add),
                                args=dumps(['x', 10]),
                                who_has={'x': [a.address]})
    assert response['status'] == 'OK'
    assert b.data['y'] == 13
    assert response['nbytes'] == sizeof(b.data['y'])
    assert isinstance(response['transfer_start'], float)
    assert isinstance(response['transfer_stop'], float)

    def bad_func():
        1 / 0

    response = yield bb.compute(key='z',
                                function=dumps(bad_func),
                                args=dumps(()),
                                close=True)
    assert not b.active
    assert response['status'] == 'error'
    assert isinstance(loads(response['exception']), ZeroDivisionError)
    if sys.version_info[0] >= 3:
        assert any('1 / 0' in line
                   for line in pluck(3, traceback.extract_tb(
                       loads(response['traceback'])))
                   if line)

    aa.close_rpc()
    yield a._close()

    assert a.address not in c.ncores and b.address in c.ncores

    assert list(c.ncores.keys()) == [b.address]

    assert isinstance(b.address, str)
    assert b.ip in b.address
    assert str(b.port) in b.address

    bb.close_rpc()
def test_pickle_numpy(protocol):
    np = pytest.importorskip("numpy")
    context = {"pickle-protocol": protocol}

    x = np.ones(5)
    assert (loads(dumps(x, protocol=protocol)) == x).all()
    assert (
        deserialize(*serialize(x, serializers=("pickle",), context=context)) == x
    ).all()

    x = np.ones(5000)
    assert (loads(dumps(x, protocol=protocol)) == x).all()
    assert (
        deserialize(*serialize(x, serializers=("pickle",), context=context)) == x
    ).all()

    x = np.array([np.arange(3), np.arange(4, 6)], dtype=object)
    x2 = loads(dumps(x, protocol=protocol))
    assert x.shape == x2.shape
    assert x.dtype == x2.dtype
    assert x.strides == x2.strides
    for e_x, e_x2 in zip(x.flat, x2.flat):
        np.testing.assert_equal(e_x, e_x2)

    h, f = serialize(x, serializers=("pickle",), context=context)
    if protocol >= 5:
        assert len(f) == 3
    else:
        assert len(f) == 1
    x3 = deserialize(h, f)
    assert x.shape == x3.shape
    assert x.dtype == x3.dtype
    assert x.strides == x3.strides
    for e_x, e_x3 in zip(x.flat, x3.flat):
        np.testing.assert_equal(e_x, e_x3)

    if protocol >= 5:
        x = np.ones(5000)

        l = []
        d = dumps(x, protocol=protocol, buffer_callback=l.append)
        assert len(l) == 1
        assert isinstance(l[0], pickle.PickleBuffer)
        assert memoryview(l[0]) == memoryview(x)
        assert (loads(d, buffers=l) == x).all()

        h, f = serialize(x, serializers=("pickle",), context=context)
        assert len(f) == 2
        assert isinstance(f[0], bytes)
        assert isinstance(f[1], memoryview)
        assert (deserialize(h, f) == x).all()
async def plugin_add(self, plugin=None, name=None):
    with log_errors(pdb=False):
        if isinstance(plugin, bytes):
            plugin = pickle.loads(plugin)
        if name is None:
            name = _get_plugin_name(plugin)
        assert name

        self.plugins[name] = plugin

        logger.info("Starting Nanny plugin %s" % name)
        if hasattr(plugin, "setup"):
            try:
                result = plugin.setup(nanny=self)
                if isawaitable(result):
                    result = await result
            except Exception as e:
                msg = error_message(e)
                return msg
        if getattr(plugin, "restart", False):
            await self.restart()

        return {"status": "OK"}
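# A minimal sketch of the duck-typed plugin object plugin_add above accepts.
# The class name is hypothetical; only the optional setup() hook (sync or
# async) and the optional `restart` attribute are exercised by the handler.
class LogOnSetupPlugin:
    restart = False  # set True to have plugin_add restart the worker after setup

    def setup(self, nanny):
        # May return an awaitable; plugin_add awaits the result if so.
        print("plugin attached to nanny at", nanny.address)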
def test_pickle_functions():
    value = 1

    def f(x):  # closure
        return x + value

    for func in [f, lambda x: x + 1, partial(add, 1)]:
        assert loads(dumps(func))(1) == func(1)
def test_pickle_functions(protocol):
    context = {"pickle-protocol": protocol}

    def make_closure():
        value = 1

        def f(x):  # closure
            return x + value

        return f

    def funcs():
        yield make_closure()
        yield (lambda x: x + 1)
        yield partial(add, 1)

    for func in funcs():
        wr = weakref.ref(func)

        func2 = loads(dumps(func, protocol=protocol))
        wr2 = weakref.ref(func2)
        assert func2(1) == func(1)

        func3 = deserialize(
            *serialize(func, serializers=("pickle",), context=context)
        )
        wr3 = weakref.ref(func3)
        assert func3(1) == func(1)

        del func, func2, func3
        gc.collect()
        assert wr() is None
        assert wr2() is None
        assert wr3() is None
def test_pickle_data(protocol):
    context = {"pickle-protocol": protocol}

    data = [1, b"123", "123", [123], {}, set()]
    for d in data:
        assert loads(dumps(d, protocol=protocol)) == d
        assert deserialize(*serialize(d, serializers=("pickle",), context=context)) == d
def deserialize(cls, serialized_flow: str) -> "Flow":
    import base64
    import zlib

    from distributed.protocol.pickle import loads

    compressed_flow = base64.decodebytes(serialized_flow.encode())
    initialized_flow: "Flow" = loads(zlib.decompress(compressed_flow))
    assert initialized_flow.initialized
    return initialized_flow
def test_pickle_out_of_band():
    class MemoryviewHolder:
        def __init__(self, mv):
            self.mv = memoryview(mv)

        def __reduce_ex__(self, protocol):
            if protocol >= 5:
                return MemoryviewHolder, (pickle.PickleBuffer(self.mv),)
            else:
                return MemoryviewHolder, (self.mv.tobytes(),)

    mv = memoryview(b"123")
    mvh = MemoryviewHolder(mv)

    if HIGHEST_PROTOCOL >= 5:
        l = []
        d = dumps(mvh, buffer_callback=l.append)
        mvh2 = loads(d, buffers=l)

        assert len(l) == 1
        assert isinstance(l[0], pickle.PickleBuffer)
        assert memoryview(l[0]) == mv
    else:
        mvh2 = loads(dumps(mvh))

    assert isinstance(mvh2, MemoryviewHolder)
    assert isinstance(mvh2.mv, memoryview)
    assert mvh2.mv == mv

    h, f = serialize(mvh, serializers=("pickle",))
    mvh3 = deserialize(h, f)

    assert isinstance(mvh3, MemoryviewHolder)
    assert isinstance(mvh3.mv, memoryview)
    assert mvh3.mv == mv

    if HIGHEST_PROTOCOL >= 5:
        assert len(f) == 2
        assert isinstance(f[0], bytes)
        assert isinstance(f[1], memoryview)
        assert f[1] == mv
    else:
        assert len(f) == 1
        assert isinstance(f[0], bytes)
def transition(self, key, start, finish, *args, **kwargs):
    if finish in status_finish:
        function_name, emit_id, _ = key.split('--')
        ts = self.worker.tasks[key]
        exc = ts.exception
        trace = ts.traceback
        trace = traceback.format_tb(trace.data)
        error = str(exc.data)
        if 'SerializedTask' in str(type(ts.runspec)):
            function = loads(ts.runspec.function)
            args = ts.runspec.args
            args = loads(args) if args else ()
            kwargs = ts.runspec.kwargs
            kwargs = loads(kwargs) if kwargs else ()
        else:
            function, args, kwargs = ts.runspec
            function_name = f'{function.__module__}.{function.__name__}'
        data = {}
        for dep in ts.dependencies:
            try:
                data[dep.key] = self.worker.data[dep.key]
            except BaseException:
                pass
        args2 = pack_data(args, data, key_types=(bytes, str))
        kwargs2 = pack_data(kwargs, data, key_types=(bytes, str))
        data = {
            'function': function_name,
            'args': args2,
            'kwargs': kwargs2,
            'exception': trace,
            'error': error,
            'key': key,
            'emit_id': emit_id,
        }
        self.persist(data, datetime.now())
def test_pickle_out_of_band(protocol):
    context = {"pickle-protocol": protocol}

    mv = memoryview(b"123")
    mvh = MemoryviewHolder(mv)

    if protocol >= 5:
        l = []
        d = dumps(mvh, protocol=protocol, buffer_callback=l.append)
        mvh2 = loads(d, buffers=l)

        assert len(l) == 1
        assert isinstance(l[0], pickle.PickleBuffer)
        assert memoryview(l[0]) == mv
    else:
        mvh2 = loads(dumps(mvh, protocol=protocol))

    assert isinstance(mvh2, MemoryviewHolder)
    assert isinstance(mvh2.mv, memoryview)
    assert mvh2.mv == mv

    h, f = serialize(mvh, serializers=("pickle",), context=context)
    mvh3 = deserialize(h, f)

    assert isinstance(mvh3, MemoryviewHolder)
    assert isinstance(mvh3.mv, memoryview)
    assert mvh3.mv == mv

    if protocol >= 5:
        assert len(f) == 2
        assert isinstance(f[0], bytes)
        assert isinstance(f[1], memoryview)
        assert f[1] == mv
    else:
        assert len(f) == 1
        assert isinstance(f[0], bytes)
def deserialize_numpy_maskedarray(header, frames):
    data_header = header["data-header"]
    data_frames = frames[: header["nframes"]]
    data = deserialize_numpy_ndarray(data_header, data_frames)

    if "mask-header" in header:
        mask_header = header["mask-header"]
        mask_frames = frames[header["nframes"] :]
        mask = deserialize_numpy_ndarray(mask_header, mask_frames)
    else:
        mask = np.ma.nomask

    pickled_fv, fill_value = header["fill-value"]
    if pickled_fv:
        fill_value = pickle.loads(fill_value)

    return np.ma.masked_array(data, mask=mask, fill_value=fill_value)
def deserialize(self, header, frames):
    cls = pickle.loads(header["type-serialized"])
    if issubclass(cls, dict):
        dd = obj = {}
    else:
        obj = object.__new__(cls)
        dd = obj.__dict__
    dd.update(header["simple"])
    for k, d in header["complex"].items():
        h = d["header"]
        f = frames[d["start"] : d["stop"]]
        nested_dict = h.get("nested-dict")
        if nested_dict:
            v = self.deserialize(nested_dict, f)
        else:
            v = deserialize(h, f)
        dd[k] = v
    return obj
def test_pickle_functions():
    def make_closure():
        value = 1

        def f(x):  # closure
            return x + value

        return f

    def funcs():
        yield make_closure()
        yield (lambda x: x + 1)
        yield partial(add, 1)

    for func in funcs():
        wr = weakref.ref(func)
        func2 = loads(dumps(func))
        wr2 = weakref.ref(func2)
        assert func2(1) == func(1)
        del func, func2
        assert wr() is None
        assert wr2() is None
def pickle_loads(header, frames):
    x, buffers = frames[0], frames[1:]

    writeable = header.get("writeable")
    if not writeable:
        writeable = len(buffers) * (None,)

    new = []
    memoryviews = map(memoryview, buffers)
    for w, mv in zip(writeable, memoryviews):
        if w == mv.readonly:
            # Mutability mismatch: the header wants the opposite of what the
            # received buffer offers, so copy into a buffer of the right kind.
            if mv.readonly:
                mv = memoryview(bytearray(mv))  # need a writeable copy
            else:
                mv = memoryview(bytes(mv))  # need a read-only copy
            if mv.nbytes > 0:
                mv = mv.cast(mv.format, mv.shape)
            else:
                mv = mv.cast(mv.format)
        new.append(mv)

    return pickle.loads(x, buffers=new)
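# A minimal sketch of how pickle_loads above is meant to be fed: frame 0 is
# the pickle bytestream, later frames are the out-of-band buffers, and the
# "writeable" header records each buffer's required mutability. Assumes
# Python >= 3.8 (pickle protocol 5); numpy is used only as a convenient
# out-of-band-capable payload.
import pickle

import numpy as np

x = np.arange(10)
buffers = []
frame0 = pickle.dumps(x, protocol=5, buffer_callback=buffers.append)
header = {"writeable": tuple(not memoryview(b).readonly for b in buffers)}
y = pickle_loads(header, [frame0] + buffers)
assert (y == x).all()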
def _decode_default(obj):
    offset = obj.get("__Serialized__", 0)
    if offset > 0:
        sub_header = msgpack.loads(
            frames[offset],
            object_hook=msgpack_decode_default,
            use_list=False,
            **msgpack_opts,
        )
        offset += 1
        sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
        if deserialize:
            if "compression" in sub_header:
                sub_frames = decompress(sub_header, sub_frames)
            return merge_and_deserialize(
                sub_header, sub_frames, deserializers=deserializers
            )
        else:
            return Serialized(sub_header, sub_frames)

    offset = obj.get("__Pickled__", 0)
    if offset > 0:
        sub_header = msgpack.loads(frames[offset])
        offset += 1
        sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
        if allow_pickle:
            return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames)
        else:
            raise ValueError(
                "Unpickle on the Scheduler isn't allowed, "
                "set `distributed.scheduler.pickle=true`"
            )

    return msgpack_decode_default(obj)
def func(dask_scheduler, plugin=None):
    p = loads(plugin)  # deserialize
    dask_scheduler.add_plugin(p)
    p.register_handlers(dask_scheduler)
def serialize_numpy_ndarray(x, context=None):
    if x.dtype.hasobject or (x.dtype.flags & np.core.multiarray.LIST_PICKLE):
        header = {"pickle": True}
        frames = [None]
        buffer_callback = lambda f: frames.append(memoryview(f))
        frames[0] = pickle.dumps(
            x,
            buffer_callback=buffer_callback,
            protocol=(context or {}).get("pickle-protocol", None),
        )
        return header, frames

    # We cannot blindly pickle the dtype as some may fail pickling,
    # so we have a mixture of strategies.
    if x.dtype.kind == "V":
        # Preserving all the information works best when pickling
        try:
            # Only use stdlib pickle as cloudpickle is slow when failing
            # (microseconds instead of nanoseconds)
            dt = (
                1,
                pickle.pickle.dumps(
                    x.dtype, protocol=(context or {}).get("pickle-protocol", None)
                ),
            )
            pickle.loads(dt[1])  # does it unpickle fine?
        except Exception:
            # dtype fails pickling => fall back on the descr if reasonable.
            if x.dtype.type is not np.void or x.dtype.alignment != 1:
                raise
            else:
                dt = (0, x.dtype.descr)
    else:
        dt = (0, x.dtype.str)

    # Only serialize broadcastable data for arrays with zero strided axes
    broadcast_to = None
    if 0 in x.strides:
        broadcast_to = x.shape
        strides = x.strides
        writeable = x.flags.writeable
        x = x[tuple(slice(None) if s != 0 else slice(1) for s in strides)]
        if not x.flags.c_contiguous and not x.flags.f_contiguous:
            # Broadcasting can only be done with contiguous arrays
            x = np.ascontiguousarray(x)
            x = np.lib.stride_tricks.as_strided(
                x,
                strides=[j if i != 0 else i for i, j in zip(strides, x.strides)],
                writeable=writeable,
            )

    if not x.shape:
        # 0d array
        strides = x.strides
        data = x.ravel()
    elif x.flags.c_contiguous or x.flags.f_contiguous:
        # Avoid a copy and respect order when unserializing
        strides = x.strides
        data = x.ravel(order="K")
    else:
        x = np.ascontiguousarray(x)
        strides = x.strides
        data = x.ravel()

    if data.dtype.fields or data.dtype.itemsize > 8:
        data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8))

    try:
        data = data.data
    except ValueError:
        # "ValueError: cannot include dtype 'M' in a buffer"
        data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)).data

    header = {
        "dtype": dt,
        "shape": x.shape,
        "strides": strides,
        "writeable": [x.flags.writeable],
    }

    if broadcast_to is not None:
        header["broadcast_to"] = broadcast_to

    frames = [data]

    return header, frames
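# A quick round-trip sketch pairing serialize_numpy_ndarray with the
# deserialize_numpy_ndarray defined earlier; it assumes the same module-level
# imports those functions already rely on (np, math, pickle, log_errors).
x = np.arange(12).reshape(3, 4)[:, ::2]  # non-contiguous view
header, frames = serialize_numpy_ndarray(x)
y = deserialize_numpy_ndarray(header, frames)
assert (y == x).all() and y.dtype == x.dtype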
def test_worker_bad_args(c, a, b):
    aa = rpc(ip=a.ip, port=a.port)
    bb = rpc(ip=b.ip, port=b.port)

    class NoReprObj(object):
        """ This object cannot be properly represented as a string. """

        def __str__(self):
            raise ValueError("I have no str representation.")

        def __repr__(self):
            raise ValueError("I have no repr representation.")

    response = yield aa.compute(key='x',
                                function=dumps(NoReprObj),
                                args=dumps(()),
                                who_has={})
    assert not a.active
    assert response['status'] == 'OK'
    assert a.data['x']
    assert isinstance(response['compute_start'], float)
    assert isinstance(response['compute_stop'], float)
    assert isinstance(response['thread'], Integral)

    def bad_func(*args, **kwargs):
        1 / 0

    class MockLoggingHandler(logging.Handler):
        """Mock logging handler to check for expected logs."""

        def __init__(self, *args, **kwargs):
            self.reset()
            logging.Handler.__init__(self, *args, **kwargs)

        def emit(self, record):
            self.messages[record.levelname.lower()].append(record.getMessage())

        def reset(self):
            self.messages = {
                'debug': [],
                'info': [],
                'warning': [],
                'error': [],
                'critical': [],
            }

    hdlr = MockLoggingHandler()
    old_level = logger.level
    logger.setLevel(logging.DEBUG)
    logger.addHandler(hdlr)

    response = yield bb.compute(key='y',
                                function=dumps(bad_func),
                                args=dumps(['x']),
                                kwargs=dumps({'k': 'x'}),
                                who_has={'x': [a.address]})
    assert not b.active
    assert response['status'] == 'error'
    # Make sure job died because of bad func and not because of bad
    # argument.
    assert isinstance(loads(response['exception']), ZeroDivisionError)
    if sys.version_info[0] >= 3:
        assert any('1 / 0' in line
                   for line in pluck(3, traceback.extract_tb(
                       loads(response['traceback'])))
                   if line)

    assert hdlr.messages['warning'][0] == " Compute Failed\n" \
        "Function: bad_func\n" \
        "args: (< could not convert arg to str >)\n" \
        "kwargs: {'k': < could not convert arg to str >}\n"
    assert re.match(r"^Send compute response to scheduler: y, "
                    r"\{.*'args': \(< could not convert arg to str >\), .*"
                    r"'kwargs': \{'k': < could not convert arg to str >\}.*\}",
                    hdlr.messages['debug'][0]) or \
        re.match(r"^Send compute response to scheduler: y, "
                 r"\{.*'kwargs': \{'k': < could not convert arg to str >\}, .*"
                 r"'args': \(< could not convert arg to str >\).*\}",
                 hdlr.messages['debug'][0])

    logger.setLevel(old_level)

    # Now we check that both workers are still alive.
    assert not a.active
    response = yield aa.compute(key='z',
                                function=dumps(add),
                                args=dumps([1, 2]),
                                who_has={},
                                close=True)
    assert not a.active
    assert response['status'] == 'OK'
    assert a.data['z'] == 3
    assert isinstance(response['compute_start'], float)
    assert isinstance(response['compute_stop'], float)
    assert isinstance(response['thread'], Integral)

    assert not b.active
    response = yield bb.compute(key='w',
                                function=dumps(add),
                                args=dumps([1, 2]),
                                who_has={},
                                close=True)
    assert not b.active
    assert response['status'] == 'OK'
    assert b.data['w'] == 3
    assert isinstance(response['compute_start'], float)
    assert isinstance(response['compute_stop'], float)
    assert isinstance(response['thread'], Integral)

    aa.close_rpc()
    bb.close_rpc()
def test_pickle_data(): data = [1, b"123", "123", [123], {}, set()] for d in data: assert loads(dumps(d)) == d
def test_pickle_data(): data = [1, b"123", "123", [123], {}, set()] for d in data: assert loads(dumps(d)) == d assert deserialize(*serialize(d, serializers=("pickle", ))) == d
def dask_loads(header, frames):
    typ = pickle.loads(header["type-serialized"])
    loads = dask_deserialize.dispatch(typ)
    return loads(header["sub-header"], frames)
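# Sketch of the registration side that dask_loads dispatches against, using
# distributed's dask_serialize/dask_deserialize dispatch tables. The Interval
# type and its frame layout are made up for illustration; the surrounding
# machinery supplies the "type-serialized" and "sub-header" fields.
import struct

from distributed.protocol.serialize import dask_deserialize, dask_serialize


class Interval:
    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi


@dask_serialize.register(Interval)
def _serialize_interval(iv):
    header = {}  # extra metadata could go here
    frames = [struct.pack("!dd", iv.lo, iv.hi)]
    return header, frames


@dask_deserialize.register(Interval)
def _deserialize_interval(header, frames):
    lo, hi = struct.unpack("!dd", frames[0])
    return Interval(lo, hi)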