Пример #1
0
    def check_out(deserialize_flag, out_value):
        """Validate a received message against the enclosing ``msg_orig``.

        ``deserialize_flag`` selects which form the ``'ser'`` entry is expected
        to take: raw bytes when true, a ``Serialized`` wrapper otherwise.
        """
        # Both messages must expose exactly the same set of keys.
        assert sorted(out_value) == sorted(msg_orig)

        # Copy first, in case the transport passed the object as-is.
        received = out_value.copy()
        received_to_ser = received.pop('to_ser')
        received_ser = received.pop('ser')

        expected = msg_orig.copy()
        for special_key in ('ser', 'to_ser'):
            del expected[special_key]
        assert received == expected

        if deserialize_flag:
            assert isinstance(received_ser, (bytes, bytearray))
            assert bytes(received_ser) == _uncompressible
            return

        assert isinstance(received_ser, Serialized)
        assert deserialize(received_ser.header, received_ser.frames) == _uncompressible
        assert isinstance(received_to_ser, list)
        (item,) = received_to_ser
        # The to_serialize() value may or may not have been actually
        # serialized (it's a transport-specific optimization).
        if not isinstance(item, Serialized):
            assert item == to_serialize(_uncompressible)
        else:
            assert deserialize(item.header, item.frames) == _uncompressible
Пример #2
0
def test_raise_error_on_serialize_write_permissions():
    """Round-tripping h5py objects from a file opened with mode 'a' raises TypeError."""
    with tmpfile() as path:
        with h5py.File(path, mode='a') as handle:
            dataset = handle.create_dataset('/x', shape=(2, 2), dtype='i4')
            handle.flush()
            # The dataset is checked first, then the file handle itself.
            for target in (dataset, handle):
                with pytest.raises(TypeError):
                    deserialize(*serialize(target))
Пример #3
0
def test_serialize_raises():
    """An exception raised inside a registered serializer reaches the caller."""

    class Foo(object):
        pass

    @dask_serialize.register(Foo)
    def dumps(f):
        # Always fail, with a message the assertion below can recognise.
        raise Exception("Hello-123")

    with pytest.raises(Exception) as excinfo:
        header, frames = serialize(Foo())
        deserialize(header, frames)

    message = str(excinfo.value)
    assert 'Hello-123' in message
Пример #4
0
def test_dumps_serialize_numpy(df):
    """Round-trip ``df`` through serialize/deserialize and compare with ``assert_eq``."""
    header, frames = serialize(df)
    # Frames are only decompressed when the header advertises compression.
    payload = decompress(header, frames) if 'compression' in header else frames
    restored = deserialize(header, payload)

    assert_eq(df, restored)
Пример #5
0
def test_dumps_serialize_numpy(x):
    """Round-trip ``x`` through serialize/deserialize and require equal contents."""
    header, frames = serialize(x)
    # Frames are only decompressed when the header advertises compression.
    payload = decompress(header, frames) if 'compression' in header else frames
    restored = deserialize(header, payload)

    np.testing.assert_equal(x, restored)
Пример #6
0
def test_dumps_serialize():
    """Check header/frame layout for plain values and for a ``MyObj`` instance."""
    # Plain values: empty header, a single frame, equal after the round trip.
    for value in (123, [1, 2, 3]):
        header, frames = serialize(value)
        assert not header
        assert len(frames) == 1
        assert deserialize(header, frames) == value

    # MyObj: the header carries a truthy 'type' entry and there is one frame.
    original = MyObj(123)
    header, frames = serialize(original)
    assert header['type']
    assert len(frames) == 1

    restored = deserialize(header, frames)
    assert restored.data == original.data
Пример #7
0
def test_serialize():
    """A 5x5 array of ones serializes to one frame and round-trips unchanged."""
    original = np.ones((5, 5))
    header, frames = serialize(original)
    assert header['type']
    assert len(frames) == 1

    # Frames are only decompressed when the header advertises compression.
    payload = decompress(header, frames) if 'compression' in header else frames
    restored = deserialize(header, payload)
    assert (restored == original).all()
Пример #8
0
def test_serialize_masked_series():
    """A cudf Series built from a masked array survives a serialize round trip.

    The comparison is done on the pandas conversions of the original and the
    deserialized series.
    """
    nelem = 50
    data = np.random.random(nelem)
    mask = utils.random_bitmask(nelem)
    bitmask = utils.expand_bits_to_bytes(mask)[:nelem]
    null_count = utils.count_zero(bitmask)
    assert null_count >= 0
    sr = cudf.Series.from_masked_array(data, mask, null_count=null_count)
    outsr = deserialize(*serialize(sr))
    # pd.testing is the public home of the assertion helpers;
    # pd.util.testing is deprecated and absent from pandas 2.x.
    pd.testing.assert_series_equal(sr.to_pandas(), outsr.to_pandas())
Пример #9
0
def test_serialize_deserialize_variable():
    """The netCDF4 variable 'x' round-trips with its dimensions, dtype and data."""
    with tmpfile() as path:
        create_test_dataset(path)
        with netCDF4.Dataset(path, mode='r') as dataset:
            original = dataset.variables['x']
            restored = deserialize(*serialize(original))
            assert isinstance(restored, netCDF4.Variable)
            assert restored.dimensions == ('x',)
            assert original.dtype == restored.dtype
            assert (original[:] == restored[:]).all()
Пример #10
0
def test_serialize_deserialize_sparse_large():
    """A very large 1-d ``sparse.COO`` array can be serialized and deserialized.

    The array has 100,000,000 stored int16 ones. Besides checking that the
    round trip completes, cheap metadata checks confirm that the result
    matches the input without comparing the full data.
    """
    n = 100000000
    x = np.arange(n)
    data = np.ones(n, dtype=np.int16)

    s = sparse.COO([x], data)

    header, frames = serialize(s)

    s2 = deserialize(header, frames)

    # Previously the result was bound but never inspected.
    assert s2.shape == s.shape
    assert s2.nnz == s.nnz
Пример #11
0
def test_serialize():
    """A 5x5 array of ones serializes to one frame and round-trips unchanged."""
    original = np.ones((5, 5))
    header, frames = serialize(original)
    assert header["type"]
    assert len(frames) == 1

    # Frames are only decompressed when the header advertises compression.
    payload = decompress(header, frames) if "compression" in header else frames
    restored = deserialize(header, payload)
    assert (restored == original).all()
Пример #12
0
def test_zero_strided_numpy_array(x, writeable):
    """A zero-strided array round-trips compactly and keeps its write flag."""
    assert 0 in x.strides
    x.setflags(write=writeable)

    header, frames = serialize(x)
    restored = deserialize(header, frames)
    np.testing.assert_equal(x, restored)

    # Ensure we transmit fewer bytes than the full array
    transmitted = sum(map(nbytes, frames))
    assert transmitted < x.nbytes

    # Ensure both arrays have the same write flag
    assert x.flags.writeable == restored.flags.writeable
Пример #13
0
def test_serialize_deserialize_variable():
    """The netCDF4 variable "x" round-trips with its dimensions, dtype and data."""
    with tmpfile() as path:
        create_test_dataset(path)
        with netCDF4.Dataset(path, mode="r") as dataset:
            original = dataset.variables["x"]
            restored = deserialize(*serialize(original))
            assert isinstance(restored, netCDF4.Variable)
            assert restored.dimensions == ("x", )
            assert original.dtype == restored.dtype
            assert (original[:] == restored[:]).all()
Пример #14
0
def test_serialize_groupby():
    """A cudf groupby object yields the same ``mean()`` after a round trip."""
    df = cudf.DataFrame()
    df['key'] = np.random.randint(0, 20, 100)
    df['val'] = np.arange(100, dtype=np.float32)
    gb = df.groupby('key')
    outgb = deserialize(*serialize(gb))

    # The original groupby is the reference; the deserialized one is under test.
    expect = gb.mean()
    got = outgb.mean()
    # pd.testing is the public home of the assertion helpers;
    # pd.util.testing is deprecated and absent from pandas 2.x.
    pd.testing.assert_frame_equal(got.to_pandas(), expect.to_pandas())
Пример #15
0
def test_serialize_deserialize_sparse_large():
    """A very large 1-d ``sparse.COO`` array can be serialized and deserialized.

    The array has 100,000,000 stored int16 ones. Besides checking that the
    round trip completes, cheap metadata checks confirm that the result
    matches the input without comparing the full data.
    """
    n = 100000000
    x = np.arange(n)
    data = np.ones(n, dtype=np.int16)

    s = sparse.COO([x], data)

    header, frames = serialize(s)

    s2 = deserialize(header, frames)

    # Previously the result was bound but never inspected.
    assert s2.shape == s.shape
    assert s2.nnz == s.nnz
Пример #16
0
def test_serialize_cupy_collection(collection, length, value):
    """Round-trip a device object, or a collection of them, via two paths.

    The object is first moved with ``device_to_host`` and sent through
    ``serialize_bytelist``/``deserialize_bytes``, then through ``serialize``
    restricted to the "pickle" serializer. After each path the result is moved
    back with ``host_to_device`` and compared against the original ``x``.

    ``value`` being a dict selects a cudf DataFrame as the payload; otherwise
    a ``cupy.arange(10)`` array is used.
    """
    # Avoid running test for length 0 (no collection) multiple times
    if length == 0 and collection is not list:
        return

    if isinstance(value, dict):
        cudf = pytest.importorskip("cudf")
        dd = pytest.importorskip("dask.dataframe")
        x = cudf.DataFrame(value)
        assert_func = dd.assert_eq
    else:
        x = cupy.arange(10)
        assert_func = assert_eq

    def check_result(res):
        # Compare a round-tripped result (bare object or collection) with ``x``.
        if length == 0:
            assert_func(res, x)
            return
        assert isinstance(res, collection)
        values = res.values() if collection is dict else res
        for v in values:
            assert_func(v, x)

    if length == 0:
        obj = device_to_host(x)
    elif collection is dict:
        obj = device_to_host(dict(zip(range(length), (x, ) * length)))
    else:
        obj = device_to_host(collection((x, ) * length))

    if length > 0:
        assert all(
            h["serializer"] == "dask" for h in obj.header["sub-headers"])
    else:
        assert obj.header["serializer"] == "dask"

    btslst = serialize_bytelist(obj)

    bts = deserialize_bytes(b"".join(btslst))
    check_result(host_to_device(bts))

    header, frames = serialize(obj, serializers=["pickle"], on_error="raise")

    # With pickle protocol 5+ one extra frame is expected per frame of ``obj``.
    if HIGHEST_PROTOCOL >= 5:
        assert len(frames) == (1 + len(obj.frames))
    else:
        assert len(frames) == 1

    obj2 = deserialize(header, frames)
    check_result(host_to_device(obj2))
Пример #17
0
def test_serialize_deserialize_dataset():
    """An h5py dataset read from a read-only file round-trips with name, file and data."""
    with tmpfile() as path:
        # Create the nested dataset, then reopen the file read-only.
        with h5py.File(path, mode='a') as writer:
            writer.create_dataset('/group1/group2/x', shape=(2, 2), dtype='i4')
        with h5py.File(path, mode='r') as reader:
            original = reader['group1/group2/x']
            restored = deserialize(*serialize(original))
            assert isinstance(restored, h5py.Dataset)
            assert original.name == restored.name
            assert original.file.filename == restored.file.filename
            assert (original[:] == restored[:]).all()
Пример #18
0
def test_serialize_deserialize_dataset():
    """A read-only netCDF4 dataset round-trips with its path and variable 'x'."""
    with tmpfile() as path:
        create_test_dataset(path)
        with netCDF4.Dataset(path, mode='r') as original:
            restored = deserialize(*serialize(original))
            assert original.filepath() == restored.filepath()
            assert isinstance(restored, netCDF4.Dataset)

            variable = restored.variables['x']
            assert variable.dimensions == ('x', )
            assert variable.dtype == np.int32
            assert (variable[:] == np.arange(3)).all()
Пример #19
0
def test_basic():
    """A fitted LinearRegression uses the 'dask' serializer and predicts the same after a round trip."""
    model = sklearn.linear_model.LinearRegression()
    model.fit([[0, 0], [1, 1], [2, 2]], [0, 1, 2])

    header, frames = serialize(model)
    assert header['serializer'] == 'dask'

    restored = deserialize(header, frames)

    samples = [[2, 3], [-1, 3]]
    assert (model.predict(samples) == restored.predict(samples)).all()
Пример #20
0
def test_basic():
    """A fitted LinearRegression uses the "dask" serializer and predicts the same after a round trip."""
    model = sklearn.linear_model.LinearRegression()
    model.fit([[0, 0], [1, 1], [2, 2]], [0, 1, 2])

    header, frames = serialize(model)
    assert header["serializer"] == "dask"

    restored = deserialize(header, frames)

    samples = [[2, 3], [-1, 3]]
    assert (model.predict(samples) == restored.predict(samples)).all()
Пример #21
0
def test_serialize_datetime():
    """A cudf DataFrame with a datetime64[ms] column survives a round trip.

    The deserialized frame, converted to pandas, must equal the pandas frame
    it was built from.
    """
    # Make frame with datetime column
    df = pd.DataFrame({'x': np.random.randint(0, 5, size=20),
                       'y': np.random.normal(size=20)})
    ts = np.arange(0, len(df), dtype=np.dtype('datetime64[ms]'))
    df['timestamp'] = ts
    gdf = cudf.DataFrame.from_pandas(df)
    # (De)serialize roundtrip
    recreated = deserialize(*serialize(gdf))
    # pd.testing is the public home of the assertion helpers;
    # pd.util.testing is deprecated and absent from pandas 2.x.
    pd.testing.assert_frame_equal(recreated.to_pandas(), df)
Пример #22
0
def test_serialize_deserialize_file():
    """A read-only h5py File round-trips with its filename, mode and contents."""
    with tmpfile() as path:
        # Create the dataset, then reopen the file read-only.
        with h5py.File(path, mode='a') as writer:
            writer.create_dataset('/x', shape=(2, 2), dtype='i4')
        with h5py.File(path, mode='r') as original:
            restored = deserialize(*serialize(original))
            assert original.filename == restored.filename
            assert isinstance(restored, h5py.File)
            assert original.mode == restored.mode

            assert restored['x'].shape == (2, 2)
Пример #23
0
def test_serialize_deserialize_file():
    """A read-only h5py File round-trips with its filename, mode and contents."""
    with tmpfile() as path:
        # Create the dataset, then reopen the file read-only.
        with h5py.File(path, mode='a') as writer:
            writer.create_dataset('/x', shape=(2, 2), dtype='i4')
        with h5py.File(path, mode='r') as original:
            restored = deserialize(*serialize(original))
            assert original.filename == restored.filename
            assert isinstance(restored, h5py.File)
            assert original.mode == restored.mode

            assert restored['x'].shape == (2, 2)
Пример #24
0
def test_serialize_deserialize_dataset():
    """An h5py dataset read from a read-only file round-trips with name, file and data."""
    with tmpfile() as path:
        # Create the nested dataset, then reopen the file read-only.
        with h5py.File(path, mode='a') as writer:
            writer.create_dataset('/group1/group2/x', shape=(2, 2), dtype='i4')
        with h5py.File(path, mode='r') as reader:
            original = reader['group1/group2/x']
            restored = deserialize(*serialize(original))
            assert isinstance(restored, h5py.Dataset)
            assert original.name == restored.name
            assert original.file.filename == restored.file.filename
            assert (original[:] == restored[:]).all()
Пример #25
0
def test_serialize_string():
    """A cudf DataFrame with a string column survives a round trip.

    The deserialized frame, converted to pandas, must equal the pandas frame
    it was built from.
    """
    # Make frame with string column
    df = pd.DataFrame({'x': np.random.randint(0, 5, size=5),
                       'y': np.random.normal(size=5)})
    str_data = ['a', 'bc', 'def', 'ghij', 'klmno']
    # NOTE: despite its name, this column holds the strings above.
    df['timestamp'] = str_data
    gdf = cudf.DataFrame.from_pandas(df)
    # (De)serialize roundtrip
    recreated = deserialize(*serialize(gdf))
    # pd.testing is the public home of the assertion helpers;
    # pd.util.testing is deprecated and absent from pandas 2.x.
    pd.testing.assert_frame_equal(recreated.to_pandas(), df)
Пример #26
0
def test_serialize_deserialize_dataset():
    """A read-only netCDF4 dataset round-trips with its path and variable 'x'."""
    with tmpfile() as path:
        create_test_dataset(path)
        with netCDF4.Dataset(path, mode='r') as original:
            restored = deserialize(*serialize(original))
            assert original.filepath() == restored.filepath()
            assert isinstance(restored, netCDF4.Dataset)

            variable = restored.variables['x']
            assert variable.dimensions == ('x',)
            assert variable.dtype == np.int32
            assert (variable[:] == np.arange(3)).all()
Пример #27
0
def test_serialize_deserialize_dataset():
    """An h5py dataset read from a read-only file round-trips with name, file and data."""
    with tmpfile() as path:
        # Create the nested dataset, then reopen the file read-only.
        with h5py.File(path, mode="a") as writer:
            writer.create_dataset("/group1/group2/x", shape=(2, 2), dtype="i4")
        with h5py.File(path, mode="r") as reader:
            original = reader["group1/group2/x"]
            restored = deserialize(*serialize(original))
            assert isinstance(restored, h5py.Dataset)
            assert original.name == restored.name
            assert original.file.filename == restored.file.filename
            assert (original[:] == restored[:]).all()
Пример #28
0
def test_serialize_deserialize_file():
    """A read-only h5py File round-trips with its filename, mode and contents."""
    with tmpfile() as path:
        # Create the dataset, then reopen the file read-only.
        with h5py.File(path, mode="a") as writer:
            writer.create_dataset("/x", shape=(2, 2), dtype="i4")
        with h5py.File(path, mode="r") as original:
            restored = deserialize(*serialize(original))
            assert original.filename == restored.filename
            assert isinstance(restored, h5py.File)
            assert original.mode == restored.mode

            assert restored["x"].shape == (2, 2)
Пример #29
0
def test_masked_array_serialize():
    """A NumPy masked array keeps its data, mask and fill value across a round trip."""
    values = (5, 6)
    mask_flags = [True, False]
    fill = 999
    original = np.ma.masked_array(values, mask=mask_flags, fill_value=fill)

    restored = deserialize(*serialize(original))

    # Explicitly test the particular elements of the masked array.
    np.testing.assert_equal(values, restored.data)
    np.testing.assert_equal(mask_flags, restored.mask)
    assert fill == restored.fill_value
Пример #30
0
def test_dumps_serialize_numpy(x):
    """Round-trip ``x``, checking frame types, contents and strides.

    Strides are only compared when ``x`` is C- or F-contiguous.
    """
    header, frames = serialize(x)
    if "compression" in header:
        frames = decompress(header, frames)

    # Every frame must be a bytes object or a memoryview.
    accepted_types = (bytes, memoryview)
    for frame in frames:
        assert isinstance(frame, accepted_types)

    restored = deserialize(header, frames)
    np.testing.assert_equal(x, restored)

    contiguous = x.flags.c_contiguous or x.flags.f_contiguous
    if contiguous:
        assert x.strides == restored.strides
Пример #31
0
def test_serialize_deserialize_group():
    """An h5py group from a read-only file round-trips with its file and members."""
    with tmpfile() as path:
        # Create the nested dataset, then reopen the file read-only.
        with h5py.File(path, mode='a') as writer:
            writer.create_dataset('/group1/group2/x', shape=(2, 2), dtype='i4')
        with h5py.File(path, mode='r') as reader:
            original = reader['/group1/group2']
            restored = deserialize(*serialize(original))

            assert isinstance(restored, h5py.Group)
            assert original.file.filename == restored.file.filename

            assert restored['x'].shape == (2, 2)
Пример #32
0
def test_serialize_cupy(shape, dtype, order, serializers):
    """Round-trip a cupy array of the given shape/dtype/order.

    The array is serialized with ``serializers`` and deserialized with the
    same list. The frame type is checked according to the first serializer:
    "cuda" frames must expose ``__cuda_array_interface__``, "dask" frames must
    be memoryviews.
    """
    # numpy.product is a deprecated alias that NumPy 2.0 removed; numpy.prod
    # is the supported spelling.
    x = cupy.arange(numpy.prod(shape), dtype=dtype)
    # Re-wrap the same device memory with the requested shape and order.
    x = cupy.ndarray(shape, dtype=x.dtype, memptr=x.data, order=order)
    header, frames = serialize(x, serializers=serializers)
    y = deserialize(header, frames, deserializers=serializers)

    if serializers[0] == "cuda":
        assert all(hasattr(f, "__cuda_array_interface__") for f in frames)
    elif serializers[0] == "dask":
        assert all(isinstance(f, memoryview) for f in frames)

    assert (x == y).all()
Пример #33
0
def test_serialize_deserialize_sparse():
    """A 4-d ``sparse.COO`` array round-trips with its data, coords and dense form."""
    x = np.random.random((2, 3, 4, 5))
    # Zero out every entry below 0.8 so the array is mostly empty.
    x[x < 0.8] = 0

    y = sparse.COO(x)
    header, frames = serialize(y)
    assert 'sparse' in header['type']
    # Deserialize the payload whose header was just checked, rather than
    # serializing ``y`` a second time.
    z = deserialize(header, frames)

    assert_allclose(y.data, z.data)
    assert_allclose(y.coords, z.coords)
    assert_allclose(y.todense(), z.todense())
Пример #34
0
def test_dumps_serialize_numpy(x):
    """Round-trip ``x``, checking frame types, contents and strides.

    Strides are only compared when ``x`` is C- or F-contiguous.
    """
    header, frames = serialize(x)
    if 'compression' in header:
        frames = decompress(header, frames)

    # Frames must be bytes or the buffer type chosen by the PY2 flag.
    buffer_type = buffer if PY2 else memoryview  # noqa: F821
    accepted_types = (bytes, buffer_type)
    for frame in frames:
        assert isinstance(frame, accepted_types)

    restored = deserialize(header, frames)
    np.testing.assert_equal(x, restored)

    contiguous = x.flags.c_contiguous or x.flags.f_contiguous
    if contiguous:
        assert x.strides == restored.strides
Пример #35
0
    def check_out_false(out_value):
        """Validate a message that was received without deserialization.

        The ``'ser'`` entry must be a ``Serialized`` wrapping 456, and
        ``'to_ser'`` a one-element list holding 123 in serialized or
        ``to_serialize`` form.
        """
        # Copy first, in case the transport passed the object as-is.
        received = out_value.copy()
        received_to_ser = received.pop('to_ser')
        received_ser = received.pop('ser')

        expected = msg_orig.copy()
        for special_key in ('ser', 'to_ser'):
            del expected[special_key]
        assert received == expected

        assert isinstance(received_ser, Serialized)
        assert deserialize(received_ser.header, received_ser.frames) == 456

        assert isinstance(received_to_ser, list)
        (item,) = received_to_ser
        # The to_serialize() value may or may not have been actually
        # serialized (it's a transport-specific optimization).
        if not isinstance(item, Serialized):
            assert item == to_serialize(123)
        else:
            assert deserialize(item.header, item.frames) == 123
Пример #36
0
def test_serialize_deserialize_group():
    """An h5py group from a read-only file round-trips with its file and members."""
    with tmpfile() as path:
        # Create the nested dataset, then reopen the file read-only.
        with h5py.File(path, mode='a') as writer:
            writer.create_dataset('/group1/group2/x', shape=(2, 2), dtype='i4')
        with h5py.File(path, mode='r') as reader:
            original = reader['/group1/group2']
            restored = deserialize(*serialize(original))

            assert isinstance(restored, h5py.Group)
            assert original.file.filename == restored.file.filename

            assert restored['x'].shape == (2, 2)
Пример #37
0
def test_serialize_deserialize_group():
    """An h5py group from a read-only file round-trips with its file and members."""
    with tmpfile() as path:
        # Create the nested dataset, then reopen the file read-only.
        with h5py.File(path, mode="a") as writer:
            writer.create_dataset("/group1/group2/x", shape=(2, 2), dtype="i4")
        with h5py.File(path, mode="r") as reader:
            original = reader["/group1/group2"]
            restored = deserialize(*serialize(original))

            assert isinstance(restored, h5py.Group)
            assert original.file.filename == restored.file.filename

            assert restored["x"].shape == (2, 2)
Пример #38
0
    def check_out_false(out_value):
        """Validate a message that was received without deserialization.

        The ``'ser'`` entry must be a ``Serialized`` wrapping 456, and
        ``'to_ser'`` a one-element list holding 123 in serialized or
        ``to_serialize`` form.
        """
        # Copy first, in case the transport passed the object as-is.
        received = out_value.copy()
        received_to_ser = received.pop('to_ser')
        received_ser = received.pop('ser')

        expected = msg_orig.copy()
        for special_key in ('ser', 'to_ser'):
            del expected[special_key]
        assert received == expected

        assert isinstance(received_ser, Serialized)
        assert deserialize(received_ser.header, received_ser.frames) == 456

        assert isinstance(received_to_ser, list)
        (item,) = received_to_ser
        # The to_serialize() value may or may not have been actually
        # serialized (it's a transport-specific optimization).
        if not isinstance(item, Serialized):
            assert item == to_serialize(123)
        else:
            assert deserialize(item.header, item.frames) == 123
Пример #39
0
def test_serialize_scipy_sparse(sparse_type, dtype):
    """A scipy sparse matrix of ``sparse_type`` round-trips via the "dask" serializer."""
    dense = numpy.array([[0, 1, 0], [2, 0, 3], [0, 4, 0]], dtype=dtype)

    # Build a COO matrix from the non-zero entries, then convert it.
    nonzero_index = dense.nonzero()
    coo = scipy_sparse.coo_matrix((dense[nonzero_index], nonzero_index))
    matrix = sparse_type(coo)

    header, frames = serialize(matrix, serializers=["dask"])
    restored = deserialize(header, frames)

    restored_dense = restored.todense()
    assert (dense == restored_dense).all()
Пример #40
0
def test_serialize_cupy(dtype):
    """A device array made with ``cuda.to_device`` holds the same data after a round trip."""
    host_source = np.arange(100, dtype=dtype)
    device_in = cuda.to_device(host_source)

    header, frames = serialize(device_in, serializers=("cuda", "dask", "pickle"))
    allowed = ("cuda", "dask", "pickle", "error")
    device_out = deserialize(header, frames, deserializers=allowed)

    # Copy both device arrays back to host buffers before comparing.
    host_in = np.empty_like(host_source)
    host_out = np.empty_like(host_source)
    device_in.copy_to_host(host_in)
    device_out.copy_to_host(host_out)
    assert (host_in == host_out).all()
Пример #41
0
def test_dumps_serialize_numpy_custom_dtype():
    """An array with NumPy's test ``rational`` dtype round-trips unchanged."""
    from six.moves import builtins
    rational_module = pytest.importorskip('numpy.core.test_rational')
    rational_type = rational_module.rational
    try:
        # Work around https://github.com/numpy/numpy/issues/9160
        builtins.rational = rational_type
        original = np.array([1], dtype=rational_type)
        restored = deserialize(*serialize(original))

        np.testing.assert_equal(original, restored)
    finally:
        # Remove the temporary builtin again.
        del builtins.rational
Пример #42
0
def test_memmap():
    """A 5x5 int32 ``np.memmap`` filled with 5 round-trips with equal contents."""
    with tmpfile('npy') as path:
        with open(path, 'wb'):  # touch file
            pass
        mapped = np.memmap(path, shape=(5, 5), dtype='i4', mode='readwrite')
        mapped[:] = 5

        header, frames = serialize(mapped)
        # Frames are only decompressed when the header advertises compression.
        payload = decompress(header, frames) if 'compression' in header else frames
        restored = deserialize(header, payload)

        np.testing.assert_equal(mapped, restored)
Пример #43
0
def test_serialize_cupy_from_numba(dtype):
    """A numba device array deserializes as cupy when the header type is overridden."""
    numba = pytest.importorskip("numba")
    np = pytest.importorskip("numpy")

    host_array = np.arange(10, dtype=dtype)
    device_array = numba.cuda.to_device(host_array)

    header, frames = serialize(device_array, serializers=("cuda", "dask", "pickle"))
    # Replace the recorded type with a pickled cupy.ndarray.
    header["type-serialized"] = pickle.dumps(cupy.ndarray)

    allowed = ("cuda", "dask", "pickle", "error")
    restored = deserialize(header, frames, deserializers=allowed)

    assert (host_array == cupy.asnumpy(restored)).all()
Пример #44
0
def test_dumps_serialize_numpy_custom_dtype():
    """An array with NumPy's test ``rational`` dtype round-trips unchanged."""
    from six.moves import builtins
    rational_module = pytest.importorskip('numpy.core.test_rational')
    rational_type = rational_module.rational
    try:
        # Work around https://github.com/numpy/numpy/issues/9160
        builtins.rational = rational_type
        original = np.array([1], dtype=rational_type)
        restored = deserialize(*serialize(original))

        np.testing.assert_equal(original, restored)
    finally:
        # Remove the temporary builtin again.
        del builtins.rational
Пример #45
0
def test_serialize_deserialize_group():
    """netCDF4 groups, and variables at several nesting levels, round-trip."""
    with tmpfile() as path:
        create_test_dataset(path)
        with netCDF4.Dataset(path, mode='r') as root:
            # Groups: name, sub-groups and variable names must be preserved.
            for group_path in ('group', 'group/group1'):
                group = root[group_path]
                restored_group = deserialize(*serialize(group))
                assert isinstance(restored_group, netCDF4.Group)
                assert restored_group.name == group.name
                assert list(group.groups) == list(restored_group.groups)
                assert list(group.variables) == list(restored_group.variables)

            # Variables taken from the root and from each nested group.
            variables = (
                root.variables['x'],
                root['group'].variables['y'],
                root['group/group1'].variables['z'],
            )
            for variable in variables:
                restored = deserialize(*serialize(variable))
                assert isinstance(restored, netCDF4.Variable)
                assert restored.dimensions == ('x',)
                assert variable.dtype == restored.dtype
                assert (variable[:] == restored[:]).all()
Пример #46
0
def test_memmap():
    """A 5x5 int32 ``np.memmap`` filled with 5 round-trips with equal contents."""
    with tmpfile("npy") as path:
        with open(path, "wb"):  # touch file
            pass
        mapped = np.memmap(path, shape=(5, 5), dtype="i4", mode="readwrite")
        mapped[:] = 5

        header, frames = serialize(mapped)
        # Frames are only decompressed when the header advertises compression.
        payload = decompress(header, frames) if "compression" in header else frames
        restored = deserialize(header, payload)

        np.testing.assert_equal(mapped, restored)
Пример #47
0
def test_serialize_deserialize_group():
    """netCDF4 groups, and variables at several nesting levels, round-trip."""
    with tmpfile() as path:
        create_test_dataset(path)
        with netCDF4.Dataset(path, mode='r') as root:
            # Groups: name, sub-groups and variable names must be preserved.
            for group_path in ('group', 'group/group1'):
                group = root[group_path]
                restored_group = deserialize(*serialize(group))
                assert isinstance(restored_group, netCDF4.Group)
                assert restored_group.name == group.name
                assert list(group.groups) == list(restored_group.groups)
                assert list(group.variables) == list(restored_group.variables)

            # Variables taken from the root and from each nested group.
            variables = (
                root.variables['x'],
                root['group'].variables['y'],
                root['group/group1'].variables['z'],
            )
            for variable in variables:
                restored = deserialize(*serialize(variable))
                assert isinstance(restored, netCDF4.Variable)
                assert restored.dimensions == ('x',)
                assert variable.dtype == restored.dtype
                assert (variable[:] == restored[:]).all()
Пример #48
0
def test_memmap():
    """A 5x5 int32 ``np.memmap`` filled with 5 round-trips with equal contents."""
    with tmpfile('npy') as path:
        with open(path, 'wb'):  # touch file
            pass
        mapped = np.memmap(path, shape=(5, 5), dtype='i4', mode='readwrite')
        mapped[:] = 5

        header, frames = serialize(mapped)
        # Frames are only decompressed when the header advertises compression.
        payload = decompress(header, frames) if 'compression' in header else frames
        restored = deserialize(header, payload)

        np.testing.assert_equal(mapped, restored)
Пример #49
0
def test_serialize_numba(dtype):
    """A numba device array holds the same data after a round trip.

    The test is skipped when ``cuda.is_available()`` reports false.
    """
    if not cuda.is_available():
        pytest.skip("CUDA is not available")

    host_source = np.arange(100, dtype=dtype)
    device_in = cuda.to_device(host_source)

    header, frames = serialize(device_in, serializers=("cuda", "dask", "pickle"))
    allowed = ("cuda", "dask", "pickle", "error")
    device_out = deserialize(header, frames, deserializers=allowed)

    # Copy both device arrays back to host buffers before comparing.
    host_in = np.empty_like(host_source)
    host_out = np.empty_like(host_source)
    device_in.copy_to_host(host_in)
    device_out.copy_to_host(host_out)
    assert (host_in == host_out).all()
Пример #50
0
def test_malicious_exception():
    """Check that a failing round trip never surfaces the "Sneaky" message.

    ``MyClass.__getstate__`` raises ``BadException``, whose ``__setstate__``
    returns an exception carrying the text "Sneaky deserialization code".
    The object is serialized twice, once with no serializers and once with
    only "pickle"; in both cases ``deserialize`` must raise, and the error
    text must not contain "Sneaky".
    """
    class BadException(Exception):
        # NOTE(review): the signature omits the usual ``state`` argument;
        # this looks deliberate for the test, confirm before changing it.
        def __setstate__(self):
            return Exception("Sneaky deserialization code")

    class MyClass(object):
        # Raises instead of returning state; presumably triggered when the
        # object is pickled (verify against the serializer implementation).
        def __getstate__(self):
            raise BadException()

    obj = MyClass()

    # First pass: empty serializer list.
    header, frames = serialize(obj, serializers=[])
    with pytest.raises(Exception) as info:
        deserialize(header, frames)

    # The error must name the offending class but not the sneaky payload.
    assert "Sneaky" not in str(info.value)
    assert "MyClass" in str(info.value)

    # Second pass: pickle only.
    header, frames = serialize(obj, serializers=['pickle'])
    with pytest.raises(Exception) as info:
        deserialize(header, frames)

    # Here the error must name the exception type instead.
    assert "Sneaky" not in str(info.value)
    assert "BadException" in str(info.value)
Пример #51
0
def test_serialize_deserialize_model():
    """A trained keras model predicts the same after two kinds of round trip.

    The first goes through ``serialize``/``deserialize``; the second wraps the
    model with ``to_serialize`` inside a message and uses ``dumps``/``loads``.
    """
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(5, input_dim=3))
    model.add(keras.layers.Dense(2))
    model.compile(optimizer='sgd', loss='mse')

    inputs = np.random.random((1, 3))
    targets = np.random.random((1, 2))
    model.train_on_batch(inputs, targets)

    header, frames = serialize(model)
    restored = deserialize(header, frames)
    assert_allclose(restored.predict(inputs), model.predict(inputs))

    message = {'model': to_serialize(model)}
    decoded = loads(dumps(message))
    assert_allclose(decoded['model'].predict(inputs), model.predict(inputs))
Пример #52
0
def test_roundtrip(obj):
    """``obj`` must compare equal, via ``obj.equals``, to its round-tripped copy."""
    # Test that the serialize/deserialize functions actually
    # work independent of distributed
    restored = deserialize(*serialize(obj))
    assert obj.equals(restored)
Пример #53
0
def test_serialize_bytestrings():
    """bytes and bytearray values are passed through as the first frame itself."""
    samples = (b'123', bytearray(b'4567'))
    for payload in samples:
        header, frames = serialize(payload)
        # Identity check: the first frame is the very same object.
        assert frames[0] is payload
        restored = deserialize(header, frames)
        assert restored == payload
Пример #54
0
def test_empty():
    """An ``Empty`` instance comes back from a round trip as an ``Empty``."""
    header, frames = serialize(Empty())
    restored = deserialize(header, frames)
    assert isinstance(restored, Empty)