Example #1
0
def test_serialize_refs_roundtrip_bytes():
    """Check that ref names survive a ``to_bytes``/``from_bytes`` round trip.

    Also checks two failure modes, both expected to raise ``ValueError``:
    serializing a model whose ref target was not passed as a layer, and
    loading the serialized bytes into a model constructed without layers.
    """

    def forward(model, X, is_train):
        # Identity forward pass paired with an identity backprop callback.
        return X, lambda dY: dY

    child = Model("a", forward)

    # "a" refers to `child`, but `child` is not passed as a layer here.
    dangling = Model("test", forward, refs={"a": child, "b": None}).initialize()
    with pytest.raises(ValueError):  # ref not in nodes
        dangling.to_bytes()

    # Same refs, this time with the referenced model supplied as a layer.
    source = Model(
        "test",
        forward,
        refs={"a": child, "b": None},
        layers=[child],
    )
    assert source.ref_names == ("a", "b")
    payload = source.to_bytes()

    # A receiver built without layers is expected to reject the payload.
    with pytest.raises(ValueError):
        Model("test", forward).from_bytes(payload)

    # A receiver with the layer but no refs picks the ref names up on load.
    target = Model("test", forward, layers=[child])
    target.from_bytes(payload)
    assert target.ref_names == ("a", "b")
Example #2
0
def test_simple_model_roundtrip_bytes_serializable_attrs():
    """Check that a custom attr round-trips through the registered hooks.

    Handlers for ``SerializableAttr`` are registered on ``serialize_attr``
    and ``deserialize_attr``; after ``to_bytes``/``from_bytes`` the model's
    ``"test"`` attr is expected to carry the value ``"foo from bytes"``.
    """

    def forward(model, X, is_train):
        # Identity forward pass paired with an identity backprop callback.
        return X, lambda dY: dY

    # Sanity-check the attr's starting state before attaching it to a model.
    custom_attr = SerializableAttr()
    assert custom_attr.value == "foo"
    assert custom_attr.to_bytes() == b"foo"

    original = Model("test", forward, attrs={"test": custom_attr})
    original.initialize()

    @serialize_attr.register(SerializableAttr)
    def _dump_serializable_attr(_, value, name, model):
        # Delegate to the attr's own byte serialization.
        return value.to_bytes()

    @deserialize_attr.register(SerializableAttr)
    def _load_serializable_attr(_, value, name, model):
        # Build a fresh attr and populate it from the stored bytes.
        return SerializableAttr().from_bytes(value)

    payload = original.to_bytes()
    restored = original.from_bytes(payload)
    assert "test" in restored.attrs
    # The changed value shows the attr went through the deserialize hook.
    assert restored.attrs["test"].value == "foo from bytes"