def test_call_and_repr(func) -> None:
    """Check that ``func`` behaves the same after a ``dump_code``/``load_code`` round trip.

    ``func`` is called once as-is and once after being reconstructed from its
    code representation. Both calls receive identical arguments (including the
    same ``field_name``) and the same NumPy RNG seed. Any difference between
    the two results can therefore only come from the serialization round trip.
    """
    global_state = {}
    x = evaluate_recipe(BASE_RECIPE, length=10, global_state=global_state)
    kwargs = dict(foo=42, bar=23)

    def _call(fn):
        # Re-seed before every call so stochastic functions draw the same
        # numbers. Pass a shallow copy of ``global_state`` so top-level
        # mutations made by one call are not seen by the other.
        np.random.seed(0)
        return fn(
            x,
            field_name="bar",
            length=10,
            global_state=global_state.copy(),
            **kwargs,
        )

    ret = _call(func)
    func_reconstructed = load_code(dump_code(func))
    ret2 = _call(func_reconstructed)
    np.testing.assert_allclose(ret2, ret)
def test_code_serialization(e) -> None:
    """Round-trip ``e`` through ``serde.dump_code``/``serde.load_code``.

    The result is compared with ``check_equality`` rather than ``==``.
    NOTE(review): presumably this is because plain ``==`` is not a scalar
    truth value for some parametrized values (e.g. array-likes) — confirm.
    """
    round_tripped = serde.load_code(serde.dump_code(e))
    assert check_equality(e, round_tripped)
[ mx.nd.random.uniform(shape=(3, 5, 2), dtype="float16"), mx.nd.random.uniform(shape=(3, 5, 2), dtype="float32"), mx.nd.random.uniform(shape=(3, 5, 2), dtype="float64"), mx.nd.array([[1, 2, 3], [-1, -2, 0]], dtype=np.uint8), mx.nd.array([[1, 2, 3], [-1, -2, 0]], dtype=np.int32), mx.nd.array([[1, 2, 3], [-1, -2, 0]], dtype=np.int64), mx.nd.array([[1, 2, 3], [1, 2, 0]], dtype=np.uint8), ], ) @pytest.mark.parametrize( "serialize_fn", [ lambda x: serde.load_json(serde.dump_json(x)), lambda x: serde.load_binary(serde.dump_binary(x)), lambda x: serde.load_code(serde.dump_code(x)), ], ) def test_ndarray_serialization(a, serialize_fn) -> None: b = serialize_fn(a) assert type(a) == type(b) assert a.dtype == b.dtype assert a.shape == b.shape assert np.all((a == b).asnumpy()) def test_timestamp_encode_decode() -> None: now = pd.Timestamp.now() assert now == serde.decode(serde.encode(now))
def test_component_ctor():
    """Check that a nested validated component survives code and JSON serde.

    ``bar01`` is constructed directly, ``bar02`` is rebuilt from
    ``dump_code(bar01)``, and ``bar03`` is rebuilt from ``dump_json(bar02)``.
    For each collection attribute, all three objects must agree on type,
    length, and contents. A ``Baz`` instance is additionally checked through a
    JSON round trip.
    """
    random.seed(5_432_671_244)

    A = 100
    B = 200
    C = 300

    # The argument order inside each constructor call fixes the order of the
    # seeded random draws; reordering it changes the generated values.
    x_list = [
        Foo(
            str(random.randint(0, A)),
            Complex(x=random.uniform(0, C), y=str(random.uniform(0, C))),
            b=random.uniform(0, B),
        )
        for _ in range(4)
    ]
    fields = [
        Foo(
            a=str(random.randint(0, A)),
            b=random.uniform(0, B),
            c=Complex(x=str(random.uniform(0, C)), y=random.uniform(0, C)),
        )
        for _ in range(5)
    ]
    x_dict = {
        i: Foo(
            b=random.uniform(0, B),
            a=str(random.randint(0, A)),
            c=Complex(x=str(random.uniform(0, C)), y=str(random.uniform(0, C))),
        )
        for i in range(6)
    }

    bar01 = Bar(x_list, input_fields=fields, x_dict=x_dict)
    bar02 = load_code(dump_code(bar01))
    bar03 = load_json(dump_json(bar02))

    # Include bar01 so the checks cover the object as originally constructed,
    # not only the two reconstructions.
    for attr, expected_type in (
        ("x_list", list),
        ("x_dict", dict),
        ("input_fields", list),
    ):
        direct, via_code, via_json = (
            getattr(bar, attr) for bar in (bar01, bar02, bar03)
        )
        assert (
            expected_type is type(direct) is type(via_code) is type(via_json)
        ), attr
        assert len(direct) == len(via_code) == len(via_json), attr
        assert direct == via_code == via_json, attr

    baz01 = Baz(a="0", b="9", c=Complex(x="1", y="2"), d="42")
    baz02 = load_json(dump_json(baz01))
    assert type(baz01) is type(baz02)
    assert baz01 == baz02
def test_code_serialization(e) -> None:
    """Check that ``e`` compares equal to its ``dump_code``/``load_code`` round trip."""
    # NOTE(review): a module keeps only the last function bound to a name; if
    # this file defines another ``test_code_serialization``, the earlier one is
    # shadowed and never collected by pytest — confirm and rename if so.
    code = serde.dump_code(e)
    reconstructed = serde.load_code(code)
    assert e == reconstructed