Example #1
0
def deserialize_numpy_ndarray(header, frames):
    """Reconstruct a numpy ndarray from a serialization ``header`` and ``frames``.

    Counterpart to ``serialize_numpy_ndarray``: either unpickles the array
    (object/custom dtypes were pickled wholesale) or rebuilds it zero-copy
    on top of the single data frame.
    """
    with log_errors():
        # Object-dtype arrays were pickled; any extra frames are the
        # out-of-band pickle5 buffers.
        if header.get("pickle"):
            return pickle.loads(frames[0], buffers=frames[1:])

        (frame, ) = frames
        (writeable, ) = header["writeable"]

        # dtype was encoded as (is_custom, payload): pickled bytes for
        # custom dtypes, otherwise a plain dtype description.
        is_custom, dt = header["dtype"]
        if is_custom:
            dt = pickle.loads(dt)
        else:
            dt = np.dtype(dt)

        # Zero-strided (broadcast) arrays shipped only the unique data;
        # "broadcast_to" carries the original logical shape.
        if header.get("broadcast_to"):
            shape = header["broadcast_to"]
        else:
            shape = header["shape"]

        x = np.ndarray(shape,
                       dtype=dt,
                       buffer=frame,
                       strides=header["strides"])
        # Restore the writeability the sender recorded.
        if not writeable:
            x.flags.writeable = False
        else:
            x = np.require(x, requirements=["W"])

        return x
Example #2
0
def test_pickle_numpy():
    """Small and large ndarrays survive a dumps/loads round-trip."""
    np = pytest.importorskip('numpy')
    for size in (5, 5000):
        arr = np.ones(size)
        assert (loads(dumps(arr)) == arr).all()
def test_pickle_numpy():
    """dumps/loads preserves ndarray contents for small and large arrays."""
    np = pytest.importorskip("numpy")

    small = np.ones(5)
    assert (loads(dumps(small)) == small).all()

    big = np.ones(5000)
    assert (loads(dumps(big)) == big).all()
Example #4
0
def test_pickle_numpy():
    """ndarrays round-trip via pickle; protocol 5 uses out-of-band buffers."""
    np = pytest.importorskip("numpy")
    x = np.ones(5)
    assert (loads(dumps(x)) == x).all()
    assert (deserialize(*serialize(x, serializers=("pickle", ))) == x).all()

    x = np.ones(5000)
    assert (loads(dumps(x)) == x).all()
    assert (deserialize(*serialize(x, serializers=("pickle", ))) == x).all()

    if HIGHEST_PROTOCOL >= 5:
        # Large arrays should travel zero-copy as pickle5 buffers.
        x = np.ones(5000)

        l = []
        d = dumps(x, buffer_callback=l.append)
        assert len(l) == 1
        assert isinstance(l[0], PickleBuffer)
        assert memoryview(l[0]) == memoryview(x)
        assert (loads(d, buffers=l) == x).all()

        # serialize() splits into (pickle bytes, raw data) frames.
        h, f = serialize(x, serializers=("pickle", ))
        assert len(f) == 2
        assert isinstance(f[0], bytes)
        assert isinstance(f[1], memoryview)
        assert (deserialize(h, f) == x).all()
Example #5
0
def test_worker(c, a, b):
    """End-to-end worker RPC test: compute, dependency transfer, errors, shutdown.

    Exercises the worker ``compute`` RPC directly: a simple task on worker
    ``a``, a dependent task on ``b`` (pulling ``x`` from ``a``), a failing
    task, and finally shutdown bookkeeping on the scheduler ``c``.
    """
    aa = rpc(ip=a.ip, port=a.port)
    bb = rpc(ip=b.ip, port=b.port)

    result = yield aa.identity()
    assert not a.active
    response = yield aa.compute(key='x',
                                function=dumps(add),
                                args=dumps([1, 2]),
                                who_has={},
                                close=True)
    assert not a.active
    assert response['status'] == 'OK'
    assert a.data['x'] == 3
    assert isinstance(response['compute_start'], float)
    assert isinstance(response['compute_stop'], float)
    assert isinstance(response['thread'], Integral)

    # 'y' depends on 'x', which lives on worker a -> triggers a transfer.
    response = yield bb.compute(key='y',
                                function=dumps(add),
                                args=dumps(['x', 10]),
                                who_has={'x': [a.address]})
    assert response['status'] == 'OK'
    assert b.data['y'] == 13
    assert response['nbytes'] == sizeof(b.data['y'])
    assert isinstance(response['transfer_start'], float)
    assert isinstance(response['transfer_stop'], float)

    def bad_func():
        1 / 0

    # Failing tasks report a pickled exception and traceback.
    response = yield bb.compute(key='z',
                                function=dumps(bad_func),
                                args=dumps(()),
                                close=True)
    assert not b.active
    assert response['status'] == 'error'
    assert isinstance(loads(response['exception']), ZeroDivisionError)
    if sys.version_info[0] >= 3:
        assert any('1 / 0' in line
                  for line in pluck(3, traceback.extract_tb(
                      loads(response['traceback'])))
                  if line)

    aa.close_rpc()
    yield a._close()

    # After closing a, only b should remain registered with the scheduler.
    assert a.address not in c.ncores and b.address in c.ncores

    assert list(c.ncores.keys()) == [b.address]

    assert isinstance(b.address, str)
    assert b.ip in b.address
    assert str(b.port) in b.address

    bb.close_rpc()
Example #6
0
def test_pickle_numpy(protocol):
    """ndarrays (plain and object-dtype) round-trip at a given pickle protocol."""
    np = pytest.importorskip("numpy")
    context = {"pickle-protocol": protocol}

    x = np.ones(5)
    assert (loads(dumps(x, protocol=protocol)) == x).all()
    assert (
        deserialize(*serialize(x, serializers=("pickle",), context=context)) == x
    ).all()

    x = np.ones(5000)
    assert (loads(dumps(x, protocol=protocol)) == x).all()
    assert (
        deserialize(*serialize(x, serializers=("pickle",), context=context)) == x
    ).all()

    # Object-dtype array: contents must match element-wise after the trip.
    x = np.array([np.arange(3), np.arange(4, 6)], dtype=object)
    x2 = loads(dumps(x, protocol=protocol))
    assert x.shape == x2.shape
    assert x.dtype == x2.dtype
    assert x.strides == x2.strides
    for e_x, e_x2 in zip(x.flat, x2.flat):
        np.testing.assert_equal(e_x, e_x2)
    h, f = serialize(x, serializers=("pickle",), context=context)
    if protocol >= 5:
        # pickle5 splits the two element arrays into out-of-band frames.
        assert len(f) == 3
    else:
        assert len(f) == 1
    x3 = deserialize(h, f)
    assert x.shape == x3.shape
    assert x.dtype == x3.dtype
    assert x.strides == x3.strides
    for e_x, e_x3 in zip(x.flat, x3.flat):
        np.testing.assert_equal(e_x, e_x3)

    if protocol >= 5:
        # Large contiguous arrays travel zero-copy via buffer_callback.
        x = np.ones(5000)

        l = []
        d = dumps(x, protocol=protocol, buffer_callback=l.append)
        assert len(l) == 1
        assert isinstance(l[0], pickle.PickleBuffer)
        assert memoryview(l[0]) == memoryview(x)
        assert (loads(d, buffers=l) == x).all()

        h, f = serialize(x, serializers=("pickle",), context=context)
        assert len(f) == 2
        assert isinstance(f[0], bytes)
        assert isinstance(f[1], memoryview)
        assert (deserialize(h, f) == x).all()
Example #7
0
    async def plugin_add(self, plugin=None, name=None):
        """Register a nanny plugin, running its ``setup`` hook if present.

        Parameters
        ----------
        plugin : bytes or object
            The plugin instance, possibly pickled.  NOTE(review): unpickling
            here assumes the sender is trusted.
        name : str, optional
            Key under which to store the plugin; derived from the plugin
            via ``_get_plugin_name`` when omitted.

        Returns
        -------
        dict
            ``{"status": "OK"}`` on success, or an error-message dict if
            the plugin's ``setup`` raised.
        """
        with log_errors(pdb=False):
            if isinstance(plugin, bytes):
                plugin = pickle.loads(plugin)

            if name is None:
                name = _get_plugin_name(plugin)

            assert name

            self.plugins[name] = plugin

            logger.info("Starting Nanny plugin %s" % name)
            if hasattr(plugin, "setup"):
                try:
                    result = plugin.setup(nanny=self)
                    if isawaitable(result):
                        result = await result
                except Exception as e:
                    msg = error_message(e)
                    return msg
            # A plugin may request an immediate restart after setup.
            if getattr(plugin, "restart", False):
                await self.restart()

            return {"status": "OK"}
Example #8
0
def test_pickle_functions():
    """Closures, lambdas and partials all survive a pickle round-trip."""
    value = 1

    def f(x):  # captures ``value``
        return x + value

    candidates = (f, lambda x: x + 1, partial(add, 1))
    for func in candidates:
        restored = loads(dumps(func))
        assert restored(1) == func(1)
Example #9
0
def test_pickle_functions(protocol):
    """Functions round-trip via pickle and serialize() without leaking refs."""
    context = {"pickle-protocol": protocol}

    def build_closure():
        captured = 1

        def closure(x):  # captures ``captured``
            return x + captured

        return closure

    def candidates():
        yield build_closure()
        yield (lambda x: x + 1)
        yield partial(add, 1)

    for func in candidates():
        wr = weakref.ref(func)

        via_pickle = loads(dumps(func, protocol=protocol))
        wr_pickle = weakref.ref(via_pickle)
        assert via_pickle(1) == func(1)

        via_serialize = deserialize(
            *serialize(func, serializers=("pickle",), context=context)
        )
        wr_serialize = weakref.ref(via_serialize)
        assert via_serialize(1) == func(1)

        # Dropping all strong refs must let every copy be collected.
        del func, via_pickle, via_serialize
        gc.collect()
        assert wr() is None
        assert wr_pickle() is None
        assert wr_serialize() is None
Example #10
0
def test_pickle_data(protocol):
    """Builtin values round-trip through dumps/loads and serialize/deserialize."""
    context = {"pickle-protocol": protocol}

    samples = [1, b"123", "123", [123], {}, set()]
    for sample in samples:
        assert loads(dumps(sample, protocol=protocol)) == sample
        assert (
            deserialize(*serialize(sample, serializers=("pickle",), context=context))
            == sample
        )
Example #11
0
def test_pickle_functions():
    """Functions capturing state still work after a pickle round-trip."""
    value = 1

    def add_value(x):  # captures ``value``
        return x + value

    for fn in (add_value, lambda x: x + 1, partial(add, 1)):
        assert loads(dumps(fn))(1) == fn(1)
Example #12
0
 def deserialize(cls, serialized_flow: str) -> "Flow":
     """Rebuild a Flow from its serialized string form.

     Inverse of the matching serializer: base64-decode, zlib-decompress,
     then unpickle.  NOTE(review): ``loads`` is a pickle load — only safe
     on trusted input.
     """
     import base64
     import zlib
     from distributed.protocol.pickle import loads
     compressed_flow = base64.decodebytes(serialized_flow.encode())
     initialized_flow: 'Flow' = loads(zlib.decompress(compressed_flow))
     # The flow must have been serialized after initialization.
     assert initialized_flow.initialized
     return initialized_flow
Example #13
0
def test_pickle_out_of_band():
    """Out-of-band (pickle5) buffers are used when the protocol supports them."""

    class MemoryviewHolder:
        # Minimal object that exposes a PickleBuffer under protocol >= 5.
        def __init__(self, mv):
            self.mv = memoryview(mv)

        def __reduce_ex__(self, protocol):
            if protocol >= 5:
                return MemoryviewHolder, (pickle.PickleBuffer(self.mv), )
            else:
                return MemoryviewHolder, (self.mv.tobytes(), )

    mv = memoryview(b"123")
    mvh = MemoryviewHolder(mv)

    if HIGHEST_PROTOCOL >= 5:
        # The buffer should surface via buffer_callback, not be inlined.
        l = []
        d = dumps(mvh, buffer_callback=l.append)
        mvh2 = loads(d, buffers=l)

        assert len(l) == 1
        assert isinstance(l[0], pickle.PickleBuffer)
        assert memoryview(l[0]) == mv
    else:
        mvh2 = loads(dumps(mvh))

    assert isinstance(mvh2, MemoryviewHolder)
    assert isinstance(mvh2.mv, memoryview)
    assert mvh2.mv == mv

    h, f = serialize(mvh, serializers=("pickle", ))
    mvh3 = deserialize(h, f)

    assert isinstance(mvh3, MemoryviewHolder)
    assert isinstance(mvh3.mv, memoryview)
    assert mvh3.mv == mv

    if HIGHEST_PROTOCOL >= 5:
        # serialize() keeps the raw buffer as a separate memoryview frame.
        assert len(f) == 2
        assert isinstance(f[0], bytes)
        assert isinstance(f[1], memoryview)
        assert f[1] == mv
    else:
        assert len(f) == 1
        assert isinstance(f[0], bytes)
Example #14
0
        def transition(self, key, start, finish, *args, **kwargs):
            """Persist failure details when a task reaches a terminal state.

            Collects the task's function/args/kwargs (unpickling them when
            the runspec is a SerializedTask), its traceback and exception
            text, and hands a flat record to ``self.persist``.
            """
            if finish in status_finish:
                # Keys are expected to look like "<name>--<emit_id>--<...>".
                function_name, emit_id, _ = key.split('--')
                ts = self.worker.tasks[key]
                exc = ts.exception
                trace = ts.traceback
                trace = traceback.format_tb(trace.data)
                error = str(exc.data)
                if 'SerializedTask' in str(type(ts.runspec)):
                    function = loads(ts.runspec.function)
                    args = ts.runspec.args
                    args = loads(args) if args else ()
                    kwargs = ts.runspec.kwargs
                    # FIX: empty kwargs must be a mapping ({}), matching the
                    # non-serialized branch below, not an empty tuple.
                    kwargs = loads(kwargs) if kwargs else {}
                else:
                    function, args, kwargs = ts.runspec
                function_name = f'{function.__module__}.{function.__name__}'

                # Gather whatever dependency values are still in worker
                # memory; missing ones are skipped (best effort).
                data = {}
                for dep in ts.dependencies:
                    try:
                        data[dep.key] = self.worker.data[dep.key]
                    except BaseException:
                        pass

                args2 = pack_data(args, data, key_types=(bytes, str))
                kwargs2 = pack_data(kwargs, data, key_types=(bytes, str))

                data = {
                    'function': function_name,
                    'args': args2,
                    'kwargs': kwargs2,
                    'exception': trace,
                    'error': error,
                    'key': key,
                    'emit_id': emit_id
                }
                self.persist(data, datetime.now())
Example #15
0
def test_pickle_out_of_band(protocol):
    """Out-of-band buffer handling at a specific pickle protocol.

    Relies on a module-level ``MemoryviewHolder`` test helper.
    """
    context = {"pickle-protocol": protocol}

    mv = memoryview(b"123")
    mvh = MemoryviewHolder(mv)

    if protocol >= 5:
        # Buffers travel out-of-band via buffer_callback at protocol 5+.
        l = []
        d = dumps(mvh, protocol=protocol, buffer_callback=l.append)
        mvh2 = loads(d, buffers=l)

        assert len(l) == 1
        assert isinstance(l[0], pickle.PickleBuffer)
        assert memoryview(l[0]) == mv
    else:
        mvh2 = loads(dumps(mvh, protocol=protocol))

    assert isinstance(mvh2, MemoryviewHolder)
    assert isinstance(mvh2.mv, memoryview)
    assert mvh2.mv == mv

    h, f = serialize(mvh, serializers=("pickle",), context=context)
    mvh3 = deserialize(h, f)

    assert isinstance(mvh3, MemoryviewHolder)
    assert isinstance(mvh3.mv, memoryview)
    assert mvh3.mv == mv

    if protocol >= 5:
        # The raw buffer is kept as its own memoryview frame.
        assert len(f) == 2
        assert isinstance(f[0], bytes)
        assert isinstance(f[1], memoryview)
        assert f[1] == mv
    else:
        assert len(f) == 1
        assert isinstance(f[0], bytes)
Example #16
0
def deserialize_numpy_maskedarray(header, frames):
    """Rebuild a numpy masked array from its serialization header and frames."""
    n_data = header["nframes"]
    data = deserialize_numpy_ndarray(header["data-header"], frames[:n_data])

    # A mask header is present only when the array actually carried a mask.
    if "mask-header" in header:
        mask = deserialize_numpy_ndarray(header["mask-header"], frames[n_data:])
    else:
        mask = np.ma.nomask

    # Fill values that could not be sent directly were pickled.
    was_pickled, fill_value = header["fill-value"]
    if was_pickled:
        fill_value = pickle.loads(fill_value)

    return np.ma.masked_array(data, mask=mask, fill_value=fill_value)
Example #17
0
    def deserialize(self, header, frames):
        """Rebuild an object (or dict) split by the matching dict serializer.

        ``header["simple"]`` holds directly transferable values; each entry
        in ``header["complex"]`` points at a slice of ``frames`` plus a
        nested header describing how to deserialize that value.
        """
        cls = pickle.loads(header["type-serialized"])
        if issubclass(cls, dict):
            # Plain dicts are rebuilt directly; no instance attributes.
            dd = obj = {}
        else:
            # Bypass __init__: attributes are restored from the header.
            obj = object.__new__(cls)
            dd = obj.__dict__
        dd.update(header["simple"])
        for k, d in header["complex"].items():
            h = d["header"]
            f = frames[d["start"] : d["stop"]]
            nested_dict = h.get("nested-dict")
            if nested_dict:
                # Recurse for values that were themselves dict-serialized.
                v = self.deserialize(nested_dict, f)
            else:
                v = deserialize(h, f)
            dd[k] = v

        return obj
Example #18
0
def test_pickle_functions():
    """Pickled functions round-trip and leave no lingering references."""

    def build_closure():
        captured = 1

        def closure(x):  # captures ``captured``
            return x + captured
        return closure

    def candidates():
        # Generator so no container keeps the originals alive.
        yield build_closure()
        yield (lambda x: x + 1)
        yield partial(add, 1)

    for original in candidates():
        ref_original = weakref.ref(original)
        restored = loads(dumps(original))
        ref_restored = weakref.ref(restored)
        assert restored(1) == original(1)
        del original, restored
        assert ref_original() is None
        assert ref_restored() is None
Example #19
0
def test_pickle_functions():
    """Deserialized callables match the originals and are garbage collected."""

    def make_adder():
        bound = 1

        def adder(x):  # captures ``bound``
            return x + bound

        return adder

    def each_callable():
        # Yield one at a time so nothing retains a strong reference.
        yield make_adder()
        yield (lambda x: x + 1)
        yield partial(add, 1)

    for fn in each_callable():
        alive = weakref.ref(fn)
        clone = loads(dumps(fn))
        clone_alive = weakref.ref(clone)
        assert clone(1) == fn(1)
        del fn, clone
        assert alive() is None
        assert clone_alive() is None
Example #20
0
def pickle_loads(header, frames):
    """Unpickle ``frames[0]`` using the remaining frames as pickle5 buffers.

    Ensures each out-of-band buffer's mutability matches what the sender
    recorded in ``header["writeable"]``, copying a buffer when it does not.
    """
    x, buffers = frames[0], frames[1:]

    # Per-buffer writeability flags; default to None ("don't care") so the
    # mismatch test below never fires for those buffers.
    writeable = header.get("writeable")
    if not writeable:
        writeable = len(buffers) * (None,)

    new = []
    memoryviews = map(memoryview, buffers)
    for w, mv in zip(writeable, memoryviews):
        # w == mv.readonly means the mutability is wrong: a buffer that
        # should be writeable arrived read-only, or vice versa -> copy.
        if w == mv.readonly:
            if mv.readonly:
                mv = memoryview(bytearray(mv))
            else:
                mv = memoryview(bytes(mv))
            # NOTE(review): the cast re-applies the copy's own format/shape;
            # appears to normalize the view (empty buffers get a bare cast).
            if mv.nbytes > 0:
                mv = mv.cast(mv.format, mv.shape)
            else:
                mv = mv.cast(mv.format)
        new.append(mv)

    return pickle.loads(x, buffers=new)
Example #21
0
        def _decode_default(obj):
            """msgpack object_hook: expand Serialized/Pickled placeholders.

            ``obj`` may carry an offset into the enclosing ``frames`` where a
            sub-header plus its sub-frames were stashed during encoding.
            """
            offset = obj.get("__Serialized__", 0)
            if offset > 0:
                sub_header = msgpack.loads(
                    frames[offset],
                    object_hook=msgpack_decode_default,
                    use_list=False,
                    **msgpack_opts,
                )
                offset += 1
                sub_frames = frames[offset:offset +
                                    sub_header["num-sub-frames"]]
                if deserialize:
                    if "compression" in sub_header:
                        sub_frames = decompress(sub_header, sub_frames)
                    return merge_and_deserialize(sub_header,
                                                 sub_frames,
                                                 deserializers=deserializers)
                else:
                    # Leave the payload opaque for later deserialization.
                    return Serialized(sub_header, sub_frames)

            offset = obj.get("__Pickled__", 0)
            if offset > 0:
                sub_header = msgpack.loads(frames[offset])
                offset += 1
                sub_frames = frames[offset:offset +
                                    sub_header["num-sub-frames"]]
                # Unpickling is gated: refuse unless explicitly allowed.
                if allow_pickle:
                    return pickle.loads(sub_header["pickled-obj"],
                                        buffers=sub_frames)
                else:
                    raise ValueError(
                        "Unpickle on the Scheduler isn't allowed, set `distributed.scheduler.pickle=true`"
                    )

            return msgpack_decode_default(obj)
Example #22
0
 def func(dask_scheduler, plugin=None):
     """Deserialize a pickled scheduler plugin, register it, and wire handlers.

     NOTE(review): presumably executed on the scheduler (e.g. via
     ``run_on_scheduler``) — confirm against the caller.
     """
     p = loads(plugin)  # unpickle the plugin instance
     dask_scheduler.add_plugin(p)
     p.register_handlers(dask_scheduler)
Example #23
0
def serialize_numpy_ndarray(x, context=None):
    """Serialize a numpy ndarray into ``(header, frames)``.

    Object dtypes fall back to pickling; everything else is shipped as a
    single raw buffer plus dtype/shape/stride metadata understood by
    ``deserialize_numpy_ndarray``.
    """
    if x.dtype.hasobject or (x.dtype.flags & np.core.multiarray.LIST_PICKLE):
        # Object arrays cannot be sent as raw buffers: pickle them, keeping
        # any pickle5 out-of-band buffers as extra frames after the bytes.
        header = {"pickle": True}
        frames = [None]
        buffer_callback = lambda f: frames.append(memoryview(f))
        frames[0] = pickle.dumps(
            x,
            buffer_callback=buffer_callback,
            protocol=(context or {}).get("pickle-protocol", None),
        )
        return header, frames

    # We cannot blindly pickle the dtype as some may fail pickling,
    # so we have a mixture of strategies.
    if x.dtype.kind == "V":
        # Preserving all the information works best when pickling
        try:
            # Only use stdlib pickle as cloudpickle is slow when failing
            # (microseconds instead of nanoseconds)
            dt = (
                1,
                pickle.pickle.dumps(x.dtype,
                                    protocol=(context or {}).get(
                                        "pickle-protocol", None)),
            )
            pickle.loads(dt[1])  # does it unpickle fine?
        except Exception:
            # dtype fails pickling => fall back on the descr if reasonable.
            if x.dtype.type is not np.void or x.dtype.alignment != 1:
                raise
            else:
                dt = (0, x.dtype.descr)
    else:
        dt = (0, x.dtype.str)

    # Only serialize broadcastable data for arrays with zero strided axes
    broadcast_to = None
    if 0 in x.strides:
        # Ship only the unique elements; the receiver re-broadcasts to the
        # recorded shape.
        broadcast_to = x.shape
        strides = x.strides
        writeable = x.flags.writeable
        x = x[tuple(slice(None) if s != 0 else slice(1) for s in strides)]
        if not x.flags.c_contiguous and not x.flags.f_contiguous:
            # Broadcasting can only be done with contiguous arrays
            x = np.ascontiguousarray(x)
            x = np.lib.stride_tricks.as_strided(
                x,
                strides=[
                    j if i != 0 else i for i, j in zip(strides, x.strides)
                ],
                writeable=writeable,
            )

    if not x.shape:
        # 0d array
        strides = x.strides
        data = x.ravel()
    elif x.flags.c_contiguous or x.flags.f_contiguous:
        # Avoid a copy and respect order when unserializing
        strides = x.strides
        data = x.ravel(order="K")
    else:
        x = np.ascontiguousarray(x)
        strides = x.strides
        data = x.ravel()

    if data.dtype.fields or data.dtype.itemsize > 8:
        # Structured / wide dtypes may not expose a plain buffer; view the
        # bytes as unsigned ints of a compatible width instead.
        data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8))

    try:
        data = data.data
    except ValueError:
        # "ValueError: cannot include dtype 'M' in a buffer"
        data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)).data

    header = {
        "dtype": dt,
        "shape": x.shape,
        "strides": strides,
        "writeable": [x.flags.writeable],
    }

    if broadcast_to is not None:
        header["broadcast_to"] = broadcast_to

    frames = [data]
    return header, frames
Example #24
0
def test_worker_bad_args(c, a, b):
    """Workers survive tasks whose arguments cannot even be stringified.

    Runs a task producing an object with no str/repr, then a failing task
    consuming it; checks the error/log paths cope with the unprintable
    argument and that both workers keep accepting work afterwards.
    """
    aa = rpc(ip=a.ip, port=a.port)
    bb = rpc(ip=b.ip, port=b.port)

    class NoReprObj(object):
        """ This object cannot be properly represented as a string. """
        def __str__(self):
            raise ValueError("I have no str representation.")
        def __repr__(self):
            raise ValueError("I have no repr representation.")

    response = yield aa.compute(key='x',
                                function=dumps(NoReprObj),
                                args=dumps(()),
                                who_has={})
    assert not a.active
    assert response['status'] == 'OK'
    assert a.data['x']
    assert isinstance(response['compute_start'], float)
    assert isinstance(response['compute_stop'], float)
    assert isinstance(response['thread'], Integral)

    def bad_func(*args, **kwargs):
        1 / 0

    class MockLoggingHandler(logging.Handler):
        """Mock logging handler to check for expected logs."""

        def __init__(self, *args, **kwargs):
            self.reset()
            logging.Handler.__init__(self, *args, **kwargs)

        def emit(self, record):
            self.messages[record.levelname.lower()].append(record.getMessage())

        def reset(self):
            self.messages = {
                'debug': [],
                'info': [],
                'warning': [],
                'error': [],
                'critical': [],
            }

    # Capture worker logs so the failure messages can be inspected below.
    hdlr = MockLoggingHandler()
    old_level = logger.level
    logger.setLevel(logging.DEBUG)
    logger.addHandler(hdlr)
    response = yield bb.compute(key='y',
                                function=dumps(bad_func),
                                args=dumps(['x']),
                                kwargs=dumps({'k': 'x'}),
                                who_has={'x': [a.address]})
    assert not b.active
    assert response['status'] == 'error'
    # Make sure job died because of bad func and not because of bad
    # argument.
    assert isinstance(loads(response['exception']), ZeroDivisionError)
    if sys.version_info[0] >= 3:
        assert any('1 / 0' in line
                  for line in pluck(3, traceback.extract_tb(
                      loads(response['traceback'])))
                  if line)
    # The log must fall back to a placeholder for the unprintable args.
    assert hdlr.messages['warning'][0] == " Compute Failed\n" \
        "Function: bad_func\n" \
        "args:     (< could not convert arg to str >)\n" \
        "kwargs:   {'k': < could not convert arg to str >}\n"
    assert re.match(r"^Send compute response to scheduler: y, " \
        "\{.*'args': \(< could not convert arg to str >\), .*" \
        "'kwargs': \{'k': < could not convert arg to str >\}.*\}",
        hdlr.messages['debug'][0]) or \
        re.match("^Send compute response to scheduler: y, " \
        "\{.*'kwargs': \{'k': < could not convert arg to str >\}, .*" \
        "'args': \(< could not convert arg to str >\).*\}",
        hdlr.messages['debug'][0])
    logger.setLevel(old_level)

    # Now we check that both workers are still alive.

    assert not a.active
    response = yield aa.compute(key='z',
                                function=dumps(add),
                                args=dumps([1, 2]),
                                who_has={},
                                close=True)
    assert not a.active
    assert response['status'] == 'OK'
    assert a.data['z'] == 3
    assert isinstance(response['compute_start'], float)
    assert isinstance(response['compute_stop'], float)
    assert isinstance(response['thread'], Integral)

    assert not b.active
    response = yield bb.compute(key='w',
                                function=dumps(add),
                                args=dumps([1, 2]),
                                who_has={},
                                close=True)
    assert not b.active
    assert response['status'] == 'OK'
    assert b.data['w'] == 3
    assert isinstance(response['compute_start'], float)
    assert isinstance(response['compute_stop'], float)
    assert isinstance(response['thread'], Integral)

    aa.close_rpc()
    bb.close_rpc()
def test_pickle_data():
    """Builtin scalars and containers survive a pickle round-trip."""
    for sample in (1, b"123", "123", [123], {}, set()):
        assert loads(dumps(sample)) == sample
Example #26
0
def test_pickle_data():
    """Basic builtin values are unchanged by dumps/loads."""
    samples = [1, b'123', '123', [123], {}, set()]
    for item in samples:
        restored = loads(dumps(item))
        assert restored == item
Example #27
0
def test_pickle_data():
    """Builtins round-trip through plain pickle and the serialize() API."""
    for value in (1, b"123", "123", [123], {}, set()):
        assert loads(dumps(value)) == value
        assert deserialize(*serialize(value, serializers=("pickle", ))) == value
Example #28
0
def dask_loads(header, frames):
    """Dispatch deserialization to the dask_deserialize handler for the type."""
    cls = pickle.loads(header["type-serialized"])
    deserializer = dask_deserialize.dispatch(cls)
    return deserializer(header["sub-header"], frames)