Ejemplo n.º 1
0
def _write_object_proto(obj, proto, asset_file_def_index):
    """Populates one field of a SavedObject proto from `obj`.

    The kind of `obj` is tested against a fixed sequence of checks; the first
    check that matches decides which field of `proto` is written, and the
    function returns immediately afterwards. Objects matching none of the
    checks are written to `proto.user_object`, using the proto produced by
    `revived_types.serialize` or, when that returns None, a generic
    placeholder SavedUserObject.

    Args:
      obj: The object to serialize.
      proto: The SavedObject proto to populate; modified in place.
      asset_file_def_index: Indexable by asset object; read only when `obj`
        is a `tracking.TrackableAsset`, to fill `asset.asset_file_def_index`.
    """
    if isinstance(obj, tracking.TrackableAsset):
        # Mark the `asset` submessage as present before filling it in.
        proto.asset.SetInParent()
        proto.asset.asset_file_def_index = asset_file_def_index[obj]
        return

    if resource_variable_ops.is_resource_variable(obj):
        variable_proto = proto.variable
        variable_proto.SetInParent()
        variable_proto.trainable = obj.trainable
        variable_proto.dtype = obj.dtype.as_datatype_enum
        variable_proto.shape.CopyFrom(obj.shape.as_proto())
        return

    if isinstance(obj, def_function.Function):
        proto.function.CopyFrom(function_serialization.serialize_function(obj))
        return

    if isinstance(obj, defun.ConcreteFunction):
        bare_function_proto = (
            function_serialization.serialize_bare_concrete_function(obj))
        proto.bare_concrete_function.CopyFrom(bare_function_proto)
        return

    if isinstance(obj, _CapturedConstant):
        # Constants are recorded by the name of the op producing the tensor.
        proto.constant.operation = obj.graph_tensor.op.name
        return

    if isinstance(obj, tracking.TrackableResource):
        # Only the presence of the `resource` field is recorded here.
        proto.resource.SetInParent()
        return

    user_object_proto = revived_types.serialize(obj)
    if user_object_proto is None:
        # Fallback for types with no matching registration: store a generic
        # placeholder SavedUserObject instead.
        user_object_proto = saved_object_graph_pb2.SavedUserObject(
            identifier="_generic_user_object",
            version=versions_pb2.VersionDef(
                producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(user_object_proto)
 def test_min_consumer_version(self):
     """A `test_type` proto with producer=5, min_consumer=5 matches nothing."""
     version_def = versions_pb2.VersionDef(
         producer=5, min_consumer=5, bad_consumers=[])
     user_object = saved_object_graph_pb2.SavedUserObject(
         identifier="test_type", version=version_def)
     self.assertIsNone(revived_types.deserialize(user_object))
 def test_min_producer_version(self):
     """A `test_type` proto with producer=3 revives an object of version 3."""
     version_def = versions_pb2.VersionDef(
         producer=3, min_consumer=0, bad_consumers=[])
     user_object = saved_object_graph_pb2.SavedUserObject(
         identifier="test_type", version=version_def)
     # deserialize() is expected to return a 2-tuple on a successful match.
     revived, _ = revived_types.deserialize(user_object)
     self.assertEqual(3, revived.version)
 def test_load_identifier_not_found(self):
     """A proto whose identifier is `_unregistered_type` deserializes to None."""
     version_def = versions_pb2.VersionDef(
         producer=1, min_consumer=1, bad_consumers=[])
     user_object = saved_object_graph_pb2.SavedUserObject(
         identifier="_unregistered_type", version=version_def)
     self.assertIsNone(revived_types.deserialize(user_object))
Ejemplo n.º 5
0
 def to_proto(self):
   """Creates a SavedUserObject proto from this object's settings.

   Returns:
     A `saved_object_graph_pb2.SavedUserObject` whose identifier is
     `self.identifier` and whose VersionDef uses `self.version` as producer,
     `self._min_consumer_version` as min_consumer and `self._bad_consumers`
     as bad_consumers.
   """
   # For now wrappers just use dependencies to save their state, so the
   # SavedUserObject doesn't depend on the object being saved.
   # TODO(allenl): Add a wrapper which uses its own proto.
   identifier = self.identifier
   version_def = versions_pb2.VersionDef(
       producer=self.version,
       min_consumer=self._min_consumer_version,
       bad_consumers=self._bad_consumers)
   return saved_object_graph_pb2.SavedUserObject(
       identifier=identifier, version=version_def)