def test_serialize_large_numpy_arrays(self):
    """Round-trip a 100,000,000-element numpy arange and compare elementwise."""
    context = SerializationContext({})
    original = numpy.arange(100000000)
    payload = context.serialize(original)
    restored = context.deserialize(payload)
    self.assertTrue(numpy.all(original == restored))
def test_serialize_memoizes_tuples(self):
    """Serialized size of a doubly-shared tuple must grow only linearly.

    Each iteration wraps the previous tuple as ``(t, t)``, so the unshared
    object graph doubles at every step. Bounding the serialized length by
    ``(depth + 1) * 100`` checks that repeated references are memoized
    rather than emitted again.
    """
    ts = SerializationContext()
    nested = (1, 2, 3)
    for depth in range(100):
        nested = (nested, nested)
        # assertLess reports both sides on failure, unlike assertTrue(a < b).
        self.assertLess(len(ts.serialize(nested)), (depth + 1) * 100)
def test_serialize_dict_doesnt_leak(self):
    """20,000 Dict(int, int) round-trips must add less than 1 MB of memory."""
    DictType = Dict(int, int)
    sample = DictType({key: key + 1 for key in range(100)})
    context = SerializationContext({})
    baselineMb = currentMemUsageMb()
    for _ in range(20000):
        context.deserialize(context.serialize(sample))
    self.assertLess(currentMemUsageMb(), baselineMb + 1)
def test_serialize_dict(self):
    """A small Dict(str, str) compares equal to its round-tripped copy."""
    context = SerializationContext({})
    original = Dict(str, str)()
    # Insert "hi" before "a", matching the original insertion order.
    for key in ("hi", "a"):
        original[key] = key
    restored = context.deserialize(context.serialize(original))
    self.assertEqual(original, restored)
def test_serialize_named_tuple_subclass(self):
    """A registered NamedTuple subclass round-trips by identity, and its
    instances serialize differently when their field values differ."""
    class X(NamedTuple(x=int)):
        def f(self):
            return self.x

    ts = SerializationContext({'X': X})
    self.assertIs(ping_pong(X, ts), X)
    # assertNotEqual shows both payloads on failure, unlike assertTrue(a != b).
    self.assertNotEqual(ts.serialize(X(x=20)), ts.serialize(X(x=21)))
    self.check_idempotence(X(x=20), ts)
def test_serialize_recursive_dict_more(self):
    """A Dict that contains itself can still be traversed after a round-trip."""
    D = Forward("D")
    D = D.define(Dict(str, OneOf(str, D)))
    context = SerializationContext({"D": D})

    cyclic = D()
    cyclic["hi"] = "bye"
    cyclic["recurses"] = cyclic

    restored = context.deserialize(context.serialize(cyclic))
    self.assertEqual(restored['recurses']['recurses']['hi'], 'bye')
def test_serialize_large_lists(self):
    """Round-trip 100 sublists of 1,000,000 ints each and print the elapsed time."""
    context = SerializationContext({})
    outer = ListOf(ListOf(int))()
    outer.resize(100)
    for inner in outer:
        inner.resize(1000000)

    started = time.time()
    restored = context.deserialize(context.serialize(outer))
    print(time.time() - started, " to roundtrip")

    self.assertEqual(outer, restored)
def test_can_serialize_type_functions(self):
    """Types produced by a TypeFunction, and their values, survive a round-trip."""
    @TypeFunction
    def List(T):
        ListT = Forward("ListT")
        return ListT.define(
            Alternative("List", Node={"head": T, "tail": ListT}, Empty={})
        )

    context = SerializationContext({'List': List})

    def roundtrip(value):
        return context.deserialize(context.serialize(value))

    self.assertIs(roundtrip(List(int)), List(int))
    self.assertIsInstance(roundtrip(List(int).Empty()), List(int))

    list_of_int = List(int)
    list_of_list = List(list_of_int)
    nested = list_of_list.Node(head=list_of_int.Empty(), tail=list_of_list.Empty())
    self.assertEqual(roundtrip(nested), nested)
def test_serialize_alternatives_as_types(self):
    """A recursive Alternative and one of its subtypes round-trip by identity."""
    A = Forward("A")
    A = A.define(Alternative("A", X={'a': int}, Y={'a': A}))
    ts = SerializationContext({'A': A})
    for candidate in (A, A.X):
        self.assertIs(ping_pong(candidate, ts), candidate)
def test_serialize_result_of_decorator(self):
    """A decorator-produced closure still computes the same value after a round-trip."""
    def decorator(f):
        def addsOne(x):
            return f(x) + 1
        return addsOne

    @decorator
    def g(x):
        return x + 1

    sc = SerializationContext({})
    restored = sc.deserialize(sc.serialize(g))
    self.assertEqual(restored(10), g(10))
def test_serializing_named_tuples_with_lambda_forward_in_loop(self):
    """Round-trip a NamedTuple whose recursive field uses a lambda forward, in a loop.

    This method used to be named ``test_serializing_named_tuples_in_loop``,
    the same name as a later method in the class. The later definition
    replaced it in the class namespace, so this test was never collected.
    It now has a distinct name so it actually runs.
    """
    NT = NamedTuple(x=OneOf(int, float), y=OneOf(int, lambda: NT))
    context = SerializationContext({'NT': NT})
    self.serializeInLoop(lambda: NT(x=10, y=NT(x=20, y=2)), context=context)
def test_serializing_named_tuples_in_loop(self):
    """Round-trip a Forward-defined recursive NamedTuple in a loop to check for leaks."""
    NT = Forward("NT")
    NT = NT.define(NamedTuple(x=OneOf(int, float), y=OneOf(int, TupleOf(NT))))
    context = SerializationContext({'NT': NT})

    def makeValue():
        inner = NT(x=20, y=2)
        return NT(x=10, y=(inner,))

    self.serializeInLoop(makeValue, context=context)
def test_serialize_is_parallel(self):
    """Serializing ListOf(int) on two threads should take about as long as on one.

    Each of 10 trials times ``thread_apply(f, [()])`` against
    ``thread_apply(f, [(), ()])``. Presumably this runs ``f`` once per
    argument tuple on its own thread (TODO confirm against thread_apply).
    The median ratio of the two durations must lie in [0.8, 1.2).
    """
    if os.environ.get('TRAVIS_CI', None):
        # Report the test as skipped rather than silently passing.
        self.skipTest("timing-sensitive parallelism check disabled when TRAVIS_CI is set")

    T = ListOf(int)
    x = T()
    x.resize(1000000)
    sc = SerializationContext({}).withoutCompression()

    def f():
        for _ in range(10):
            sc.deserialize(sc.serialize(x, T), T)

    ratios = []
    for _ in range(10):
        # perf_counter is monotonic and high-resolution, so it suits interval timing.
        t0 = time.perf_counter()
        thread_apply(f, [()])
        t1 = time.perf_counter()
        thread_apply(f, [(), ()])
        t2 = time.perf_counter()
        ratios.append((t2 - t1) / (t1 - t0))

    ratios.sort()
    # Take the median sample so a few noisy trials don't decide the outcome.
    ratio = ratios[len(ratios) // 2]
    # expect the ratio to be close to 1, but have some error margin
    self.assertGreaterEqual(ratio, .8, ratios)
    self.assertLess(ratio, 1.2, ratios)
def test_serializing_alternatives_in_loop(self):
    """Loop-serialize a recursive Alternative type, a subtype, and an instance."""
    AT = Forward("AT")
    AT = AT.define(Alternative("AT", X={'x': int, 'y': float}, Y={'x': int, 'y': AT}))
    context = SerializationContext({'AT': AT})

    makers = [
        lambda: AT,
        lambda: AT.Y,
        lambda: AT.X(x=10, y=20),
    ]
    for maker in makers:
        self.serializeInLoop(maker, context=context)
def test_serializing_objects_in_loop(self):
    """Loop-serialize a small graph of plain Python objects to check for leaks."""
    class X:
        def __init__(self, a=None, b=None, c=None):
            self.a = a
            self.b = b
            self.c = c

    context = SerializationContext({'X': X})

    def buildGraph():
        return X(a=X(), b=[1, 2, 3], c=X(a=X()))

    self.serializeInLoop(buildGraph, context=context)
def test_serialize_and_numpy(self):
    """numpy.ones round-trips with and without compression, and the
    uncompressed form is more than twice the compressed size."""
    x = numpy.ones(10000)
    ts = SerializationContext()

    def roundTripsExactly():
        # Reads ts at call time, so it reflects the current compression setting.
        return numpy.all(x == ts.deserialize(ts.serialize(x)))

    self.assertTrue(roundTripsExactly())
    sizeCompressed = len(ts.serialize(x))

    ts.compressionEnabled = False
    self.assertTrue(roundTripsExactly())
    sizeNotCompressed = len(ts.serialize(x))

    self.assertTrue(sizeNotCompressed > sizeCompressed * 2, (sizeNotCompressed, sizeCompressed))
def ping_pong(obj, serialization_context=None):
    """Serialize ``obj`` without compression, then deserialize it again.

    Args:
        obj: the value to round-trip.
        serialization_context: the context to use. A fresh
            ``SerializationContext()`` is created when it is None.

    Returns:
        The deserialized copy of ``obj``.

    Raises:
        Whatever deserialization raises. Before re-raising, the raw
        serialized data and a pretty-printed decoded view of it are
        written to stdout for debugging.
    """
    if serialization_context is None:
        serialization_context = SerializationContext()
    # Derive the uncompressed context once and use it in both directions.
    uncompressed = serialization_context.withoutCompression()
    s = uncompressed.serialize(obj)
    try:
        return uncompressed.deserialize(s)
    except Exception:
        print("FAILED TO DECODE:")
        print(s)
        # PrettyPrinter.pprint writes to stdout itself and returns None.
        # Wrapping it in print() used to emit a stray "None" line.
        pprint.PrettyPrinter(indent=2).pprint(decodeSerializedObject(s))
        raise
def test_serialize_class_instance(self):
    """An instance of a registered class round-trips without embedding method constants.

    Checks both the context methods and the module-level
    ``serialize``/``deserialize`` helpers.
    """
    class A:
        def __init__(self, x):
            self.x = x

        def f(self):
            return b"an embedded string"

    ts = SerializationContext({'A': A})
    serialization = ts.serialize(A(10))
    # Presumably A is serialized by its registered name, so the constant in
    # A.f should not appear in the payload. assertNotIn shows the payload on failure.
    self.assertNotIn(b'an embedded string', serialization)

    anA = ts.deserialize(serialization)
    self.assertEqual(anA.x, 10)

    anA2 = deserialize(A, serialize(A, A(10), ts), ts)
    self.assertEqual(anA2.x, 10)
def serializeInLoop(self, objectMaker, context=None):
    """Round-trip ``objectMaker()`` repeatedly for about 0.25s.

    Runs a full ``gc.collect()`` after each iteration, then asserts that
    memory usage (per ``currentMemUsageMb``) grew by less than 1.0 MB.
    """
    context = context or SerializationContext({})
    startingMb = currentMemUsageMb()
    startedAt = time.time()
    while True:
        if not (time.time() - startedAt < .25):
            break
        payload = context.serialize(objectMaker())
        context.deserialize(payload)
        gc.collect()
    growthMb = currentMemUsageMb() - startingMb
    self.assertLess(growthMb, 1.0)
def test_serialize_recursive_object(self):
    """An object whose attribute points back at itself keeps that self-reference."""
    class AnObject:
        def __init__(self, o):
            self.o = o

    ts = SerializationContext({'O': AnObject})
    selfRef = AnObject(None)
    selfRef.o = selfRef
    restored = ping_pong(selfRef, ts)
    self.assertIs(restored.o, restored)
def test_serialize_lambdas(self):
    """A function whose source file has been deleted can still be re-serialized.

    Steps:
      1. Write a throwaway module into a temporary directory and import it.
      2. Round-trip its function ``f``.
      3. Delete the directory and clear the ast_util caches.
      4. Serialize the deserialized copy again.

    The temporary directory and the module are removed from ``sys.path`` and
    ``sys.modules`` afterwards so they do not leak into other tests.
    """
    sc = SerializationContext({})
    module_name = 'weird_serialization_test'
    tf = None
    try:
        with tempfile.TemporaryDirectory() as tf:
            fpath = os.path.join(tf, module_name + ".py")
            with open(fpath, "w") as f:
                f.write("def f(x):\n return x + 1\n")
            sys.path.append(tf)
            # The module file was created at runtime; per the importlib docs,
            # invalidate finder caches so the import system is sure to see it.
            importlib.invalidate_caches()
            m = importlib.import_module(module_name)

            # verify we can serialize this
            deserialized_f = sc.deserialize(sc.serialize(m.f))
            self.assertEqual(deserialized_f(10), 11)

        # self.assertFalse still runs under `python -O`, unlike a bare assert.
        self.assertFalse(os.path.exists(fpath))
        ast_util.clearAllCaches()

        # at this point, the backing data for serialization is not there
        # and also, the cache is cleared.
        deserialized_f_2 = sc.deserialize(sc.serialize(deserialized_f))
        self.assertEqual(deserialized_f_2(10), 11)
    finally:
        # Undo the global import-state changes made above.
        if tf is not None and tf in sys.path:
            sys.path.remove(tf)
        sys.modules.pop(module_name, None)
def test_serialize_objects(self):
    """A plain object holding an int round-trips to an instance of the same class."""
    class AnObject:
        def __init__(self, o):
            self.o = o

    ts = SerializationContext({'O': AnObject})
    restored = ping_pong(AnObject(123), ts)
    self.assertIsInstance(restored, AnObject)
    self.assertEqual(restored.o, 123)
def test_serialization_context_queries(self):
    """objectFromName and nameForObject map names and objects in both directions.

    ``True``/``False`` are singletons, so ``assertIs`` is correct for the
    objects. The names are strings, whose identity depends on interning, so
    they are compared by value with ``assertEqual``.
    """
    sc = SerializationContext({
        'X': False,
        'Y': True,
    })
    self.assertIs(sc.objectFromName('X'), False)
    self.assertEqual(sc.nameForObject(False), 'X')
    self.assertIs(sc.objectFromName('Y'), True)
    self.assertEqual(sc.nameForObject(True), 'Y')
def test_serialize_nested_names(self):
    """Classes nested inside a global class round-trip by identity via dotted names."""
    # Nested is bound at module level, presumably so its qualified path can
    # be resolved during (de)serialization.
    global Nested

    class Nested:
        class A:
            class B:
                class C:
                    pass

    sc = SerializationContext({
        'Nested': Nested,
        'Nested.A': Nested.A,
        'Nested.A.B': Nested.A.B,
        'Nested.A.B.C': Nested.A.B.C,
    })
    candidates = (Nested.A, Nested.A.B, Nested.A.B.C)
    for candidate in candidates:
        with self.subTest(obj=candidate):
            self.assertIs(candidate, ping_pong(candidate, sc))
def test_serialize_primitive_compound_types(self):
    """Each compound type constructor result round-trips to the identical type object."""
    class A:
        pass

    B = Alternative("B", X={'a': A})
    ts = SerializationContext({'A': A, 'B': B})

    compoundTypes = (
        ConstDict(int, float),
        NamedTuple(x=int, y=str),
        TupleOf(bool),
        Tuple(int, int, bool),
        OneOf(int, float),
        OneOf(1, 2, 3, "hi", b"goodbye"),
        TupleOf(NamedTuple(x=int)),
        TupleOf(object),
        TupleOf(A),
        TupleOf(B),
    )
    for compoundType in compoundTypes:
        self.assertIs(ping_pong(compoundType, ts), compoundType)
def test_serialize_and_threads(self):
    """Ten threads sharing one context each call ping_pong for about 1 second.

    Each thread appends to a shared list only if its loop finishes without
    raising, so the list length must equal the number of threads.
    """
    class A:
        def __init__(self, x):
            self.x = x

    ts = SerializationContext({'A': A})
    finished = []

    def worker():
        startedAt = time.time()
        while time.time() - startedAt < 1.0:
            ping_pong(A(10), ts)
        finished.append(True)

    workers = [threading.Thread(target=worker) for _ in range(10)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()

    self.assertEqual(len(finished), len(workers))
def test_inject_exception_into_context(self):
    """Exceptions raised by nameForObject / objectFromName propagate out of
    serialize / deserialize respectively."""
    NT = NamedTuple()
    failsOnName = SerializationContext({'NT': NT})
    failsOnLookup = SerializationContext({'NT': NT})

    def throws(*args):
        raise Exception("Test Exception")

    failsOnName.nameForObject = throws
    failsOnLookup.objectFromName = throws

    with self.assertRaisesRegex(Exception, "Test Exception"):
        failsOnName.serialize(NT)

    data = failsOnLookup.serialize(NT)
    with self.assertRaisesRegex(Exception, "Test Exception"):
        failsOnLookup.deserialize(data)
def test_serializing_tuple_of_in_loop(self):
    """Loop-serialize a TupleOf(int) value to check for leaks."""
    TO = TupleOf(int)
    self.serializeInLoop(
        lambda: TO((1, 2, 3, 4, 5)),
        context=SerializationContext({'TO': TO}),
    )
def test_bad_serialization_context(self):
    """An empty name, whether str or bytes, makes the constructor raise AssertionError."""
    for emptyName in ('', b''):
        with self.assertRaises(AssertionError):
            SerializationContext({emptyName: int})
def test_serialize_alternatives(self):
    """A subtype of a lambda-forwarded Alternative round-trips by identity."""
    A = Alternative("A", X={'a': int}, Y={'a': lambda: A})
    ts = SerializationContext({'A': A})
    subtype = A.X
    self.assertIs(ping_pong(subtype, ts), subtype)