async def test_context_specific_serialization(c, s, a, b):
    """Check that deserialized objects carry the context of their last transfer.

    An object is created on worker ``a``, forced over to worker ``b`` and
    finally gathered into the client process.  After each hop the
    ``context`` attribute of the received object is inspected to confirm
    which peer sent it.

    ``c`` is the client, ``a`` and ``b`` are the two workers; ``s`` is not
    used by this test.
    """
    # NOTE(review): presumably ``my_loads`` stores the deserialization
    # context on the returned object -- confirm against its definition.
    register_serialization_family("my-ser", my_dumps, my_loads)

    try:
        # Create the object on A, force communication to B
        x = c.submit(MyObject, x=1, y=2, workers=a.address)
        y = c.submit(lambda x: x, x, workers=b.address)

        await wait(y)

        key = y.key

        def check(dask_worker):
            # Get the context from the object stored on B
            my_obj = dask_worker.data[key]
            return my_obj.context

        result = await c.run(check, workers=[b.address])
        # The copy held by B must have been sent by A (see origin worker)
        assert result[b.address]["sender"]["address"] == a.address

        z = await y  # bring object to local process
        assert z.x == 1 and z.y == 2
        # The local copy was fetched from B, so B is now the sender
        assert z.context["sender"]["address"] == b.address
    finally:
        # Always unregister the family so it does not leak into other tests
        from distributed.protocol.serialize import families

        del families["my-ser"]
def test_context_specific_serialization(c, s, a, b):
    """Check that deserialized objects carry the context of their last transfer.

    Generator-style variant: asynchronous steps are driven with ``yield``
    rather than ``await``.  An object is created on worker ``a``, forced
    over to worker ``b`` and finally gathered into the client process; the
    ``context`` attribute of each received copy is checked for the sender.

    ``c`` is the client, ``a`` and ``b`` are the two workers; ``s`` is not
    used by this test.
    """
    # NOTE(review): an ``async def`` with this same name also appears in this
    # source; if both live in one module the later definition shadows the
    # earlier one -- confirm which version is meant to be kept.
    register_serialization_family("my-ser", my_dumps, my_loads)

    try:
        # Create the object on A, force communication to B
        x = c.submit(MyObject, x=1, y=2, workers=a.address)
        y = c.submit(lambda x: x, x, workers=b.address)

        yield wait(y)

        key = y.key

        def check(dask_worker):
            # Get the context from the object stored on B
            my_obj = dask_worker.data[key]
            return my_obj.context

        result = yield c.run(check, workers=[b.address])
        # The copy held by B must have been sent by A (see origin worker)
        assert result[b.address]["sender"] == a.address

        z = yield y  # bring object to local process
        assert z.x == 1 and z.y == 2
        # The local copy was fetched from B, so B is now the sender
        assert z.context["sender"] == b.address
    finally:
        # Always unregister the family so it does not leak into other tests
        from distributed.protocol.serialize import families

        del families["my-ser"]
def test_different_compression_families():
    """Test serialization of a collection of items that use different compression

    This scenario happens for instance when serializing collections of
    cupy and numpy arrays.

    Two throw-away serialization families are registered for the duration
    of the test and removed again afterwards, so the global family registry
    is left as it was found.
    """
    from distributed.protocol.serialize import families

    class MyObjWithCompression:
        pass

    class MyObjWithNoCompression:
        pass

    def my_dumps_compression(obj, context=None):
        # Decline anything but our own type so the next family gets a try
        if not isinstance(obj, MyObjWithCompression):
            raise NotImplementedError()
        header = {"compression": [True]}
        return header, [bytes(2**20)]

    def my_dumps_no_compression(obj, context=None):
        # Decline anything but our own type so the next family gets a try
        if not isinstance(obj, MyObjWithNoCompression):
            raise NotImplementedError()
        header = {"compression": [False]}
        return header, [bytes(2**20)]

    def my_loads(header, frames):
        # Only ``serialize`` is called below, so this loader is never run here
        return pickle.loads(frames[0])

    try:
        register_serialization_family(
            "with-compression", my_dumps_compression, my_loads
        )
        register_serialization_family(
            "no-compression", my_dumps_no_compression, my_loads
        )

        header, _ = serialize(
            [MyObjWithCompression(), MyObjWithNoCompression()],
            serializers=("with-compression", "no-compression"),
            on_error="raise",
            iterate_collection=True,
        )
        # Each item's compression flag must survive in the merged header
        assert header["compression"] == [True, False]
    finally:
        # Unregister both families; tolerate a registration that never happened
        families.pop("with-compression", None)
        families.pop("no-compression", None)