def test_serialize_refs_roundtrip_bytes():
    """Check how model refs behave across ``to_bytes`` / ``from_bytes``.

    Three situations are exercised:

    * a model whose ref target is not one of its layers is expected to
      raise ``ValueError`` from ``to_bytes``;
    * loading the bytes into a model built without layers is expected to
      raise ``ValueError``;
    * loading the bytes into a model built with the same layer is expected
      to restore the ref names ``("a", "b")``.
    """

    def forward(model, X, is_train):
        # Pass-through forward: the input is returned unchanged, together
        # with a callback that returns the gradient unchanged.
        def backprop(dY):
            return dY

        return X, backprop

    child = Model("a", forward)
    expected_ref_names = ("a", "b")

    def build_refs():
        # A fresh dict per model, matching the original's separate literals.
        return {"a": child, "b": None}

    # Ref "a" points at `child`, but no layers are passed here.
    detached = Model("test", forward, refs=build_refs()).initialize()
    with pytest.raises(ValueError):  # ref not in nodes
        detached.to_bytes()

    # Same refs, this time with `child` also passed as a layer.
    attached = Model("test", forward, refs=build_refs(), layers=[child])
    assert attached.ref_names == expected_ref_names
    payload = attached.to_bytes()

    # The receiving model is built without any layers.
    with pytest.raises(ValueError):
        Model("test", forward).from_bytes(payload)

    # The receiving model has the layer but no refs of its own; the ref
    # names are expected to come from the payload.
    restored = Model("test", forward, layers=[child])
    restored.from_bytes(payload)
    assert restored.ref_names == expected_ref_names
def test_simple_model_roundtrip_bytes_serializable_attrs():
    """Check that a ``SerializableAttr`` attr round-trips via custom handlers.

    Handlers for ``SerializableAttr`` are registered on ``serialize_attr``
    and ``deserialize_attr``; after ``to_bytes`` followed by ``from_bytes``
    the ``"test"`` attr is expected to carry the value ``"foo from bytes"``.
    """

    def forward(model, X, is_train):
        # Pass-through forward: the input is returned unchanged, together
        # with a callback that returns the gradient unchanged.
        def backprop(dY):
            return dY

        return X, backprop

    # Sanity-check the attr's starting state before attaching it to a model.
    custom_attr = SerializableAttr()
    assert custom_attr.value == "foo"
    assert custom_attr.to_bytes() == b"foo"

    net = Model("test", forward, attrs={"test": custom_attr})
    net.initialize()

    # Register the type-specific handlers used for SerializableAttr values.
    @serialize_attr.register(SerializableAttr)
    def serialize_attr_custom(_, value, name, model):
        return value.to_bytes()

    @deserialize_attr.register(SerializableAttr)
    def deserialize_attr_custom(_, value, name, model):
        return SerializableAttr().from_bytes(value)

    payload = net.to_bytes()
    restored = net.from_bytes(payload)

    assert "test" in restored.attrs
    assert restored.attrs["test"].value == "foo from bytes"