Exemple #1
0
def _write_object_proto(obj, proto, asset_file_def_index):
    """Fills `proto` (a SavedObject) with the description matching `obj`.

    The kind of `obj` decides which field of `proto` is populated; the first
    matching check wins and the function returns immediately afterwards.

    Args:
      obj: The object to describe.
      proto: The SavedObject proto that is mutated in place.
      asset_file_def_index: Lookup indexed by asset object, giving the value
        stored in `proto.asset.asset_file_def_index`.
    """
    if isinstance(obj, tracking.TrackableAsset):
        asset_proto = proto.asset
        asset_proto.SetInParent()
        asset_proto.asset_file_def_index = asset_file_def_index[obj]
        return

    if resource_variable_ops.is_resource_variable(obj):
        variable_proto = proto.variable
        variable_proto.SetInParent()
        variable_proto.trainable = obj.trainable
        variable_proto.dtype = obj.dtype.as_datatype_enum
        variable_proto.shape.CopyFrom(obj.shape.as_proto())
        return

    if isinstance(obj, def_function.Function):
        serialized_function = function_serialization.serialize_function(obj)
        proto.function.CopyFrom(serialized_function)
        return

    if isinstance(obj, defun.ConcreteFunction):
        serialized_concrete = (
            function_serialization.serialize_bare_concrete_function(obj))
        proto.bare_concrete_function.CopyFrom(serialized_concrete)
        return

    if isinstance(obj, _CapturedConstant):
        proto.constant.operation = obj.graph_tensor.op.name
        return

    if isinstance(obj, tracking.TrackableResource):
        # Nothing to record beyond marking the resource field as present.
        proto.resource.SetInParent()
        return

    user_object_proto = revived_types.serialize(obj)
    if user_object_proto is None:
        # No registered type matched: describe it as a generic user object.
        generic_version = versions_pb2.VersionDef(
            producer=1, min_consumer=1, bad_consumers=[])
        user_object_proto = saved_object_graph_pb2.SavedUserObject(
            identifier="_generic_user_object", version=generic_version)
    proto.user_object.CopyFrom(user_object_proto)
Exemple #2
0
def _write_object_proto(obj, proto, asset_file_def_index, node_ids):
  """Fills `proto` (a SavedObject) with the description matching `obj`.

  The kind of `obj` decides which field of `proto` is populated; the first
  matching check wins and the function returns immediately afterwards.

  Args:
    obj: The object to describe.
    proto: The SavedObject proto that is mutated in place.
    asset_file_def_index: Lookup indexed by asset object, giving the value
      stored in `proto.asset.asset_file_def_index`.
    node_ids: Passed through unchanged to the function serializers.
  """
  if isinstance(obj, tracking.TrackableAsset):
    asset_proto = proto.asset
    asset_proto.SetInParent()
    asset_proto.asset_file_def_index = asset_file_def_index[obj]
    return

  if resource_variable_ops.is_resource_variable(obj):
    variable_proto = proto.variable
    variable_proto.SetInParent()
    variable_proto.trainable = obj.trainable
    variable_proto.dtype = obj.dtype.as_datatype_enum
    variable_proto.shape.CopyFrom(obj.shape.as_proto())
    return

  if isinstance(obj, def_function.Function):
    serialized_function = function_serialization.serialize_function(
        obj, node_ids)
    proto.function.CopyFrom(serialized_function)
    return

  if isinstance(obj, defun.ConcreteFunction):
    serialized_concrete = function_serialization.serialize_concrete_function(
        obj, node_ids)
    proto.concrete_function.CopyFrom(serialized_concrete)
    return

  user_object_proto = revived_types.serialize(obj)
  if user_object_proto is None:
    # No registered type matched: describe it as a generic user object.
    generic_version = versions_pb2.VersionDef(
        producer=1, min_consumer=1, bad_consumers=[])
    user_object_proto = saved_object_graph_pb2.SavedUserObject(
        identifier="_generic_user_object", version=generic_version)
  proto.user_object.CopyFrom(user_object_proto)
Exemple #3
0
def _write_object_proto(obj, proto, asset_file_def_index):
    """Saves an object into SavedObject proto.

    Exactly one field of `proto` is populated, chosen by the kind of `obj`.
    The checks run in order and the first match wins.

    Args:
      obj: The object to describe.
      proto: The SavedObject proto that is mutated in place.
      asset_file_def_index: Lookup indexed by asset object, giving the value
        stored in `proto.asset.asset_file_def_index`.

    Raises:
      ValueError: If `obj` is a resource variable whose name does not end
        with ":0".
    """
    if isinstance(obj, tracking.TrackableAsset):
        proto.asset.SetInParent()
        proto.asset.asset_file_def_index = asset_file_def_index[obj]
    elif resource_variable_ops.is_resource_variable(obj):
        proto.variable.SetInParent()
        if not obj.name.endswith(":0"):
            # The message used to contain a bare "%s" that was never filled
            # in; interpolate the variable name so the error is actionable.
            raise ValueError("Cowardly refusing to save variable %s because of"
                             " unexpected suffix which won't be restored."
                             % (obj.name,))
        proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
        proto.variable.trainable = obj.trainable
        proto.variable.dtype = obj.dtype.as_datatype_enum
        proto.variable.synchronization = obj.synchronization.value
        proto.variable.aggregation = obj.aggregation.value
        proto.variable.shape.CopyFrom(obj.shape.as_proto())
    elif isinstance(obj, def_function.Function):
        proto.function.CopyFrom(function_serialization.serialize_function(obj))
    elif isinstance(obj, defun.ConcreteFunction):
        proto.bare_concrete_function.CopyFrom(
            function_serialization.serialize_bare_concrete_function(obj))
    elif isinstance(obj, _CapturedConstant):
        proto.constant.operation = obj.graph_tensor.op.name
    elif isinstance(obj, tracking.CapturableResource):
        proto.resource.device = obj._resource_device  # pylint: disable=protected-access
    else:
        registered_type_proto = revived_types.serialize(obj)
        if registered_type_proto is None:
            # Fallback for types with no matching registration
            registered_type_proto = saved_object_graph_pb2.SavedUserObject(
                identifier="_generic_user_object",
                version=versions_pb2.VersionDef(producer=1,
                                                min_consumer=1,
                                                bad_consumers=[]))
        proto.user_object.CopyFrom(registered_type_proto)
Exemple #4
0
def _write_object_proto(obj, proto, asset_file_def_index):
  """Saves an object into SavedObject proto.

  Exactly one field of `proto` is populated, chosen by the kind of `obj`.
  The checks run in order and the first match wins.

  Args:
    obj: The object to describe.
    proto: The SavedObject proto that is mutated in place.
    asset_file_def_index: Lookup indexed by asset object, giving the value
      stored in `proto.asset.asset_file_def_index`.

  Raises:
    ValueError: If `obj` is a resource variable whose name does not end
      with ":0".
  """
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    if not obj.name.endswith(":0"):
      # The message used to contain a bare "%s" that was never filled in;
      # interpolate the variable name so the error is actionable.
      raise ValueError("Cowardly refusing to save variable %s because of"
                       " unexpected suffix which won't be restored."
                       % (obj.name,))
    proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.synchronization = obj.synchronization.value
    proto.variable.aggregation = obj.aggregation.value
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(obj))
  elif isinstance(obj, _CapturedConstant):
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object",
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(registered_type_proto)
Exemple #5
0
 def test_most_recent_version_saved(self):
     """Checks the version data written for, and revived from, CustomTestClass.

     The serialized proto must list [3] as its bad consumers, and the object
     obtained by deserializing it must be a CustomTestClass at version 4.
     """
     saved_proto = revived_types.serialize(CustomTestClass(None))
     recorded_bad_consumers = saved_proto.version.bad_consumers
     self.assertEqual([3], recorded_bad_consumers)

     # deserialize() hands back a pair; only the revived object matters here.
     revived, _ = revived_types.deserialize(saved_proto)
     self.assertIsInstance(revived, CustomTestClass)
     self.assertEqual(4, revived.version)
Exemple #6
0
 def test_save_typecheck(self):
     """Checks that serializing a plain AutoTrackable yields None."""
     plain_trackable = autotrackable.AutoTrackable()
     serialized = revived_types.serialize(plain_trackable)
     self.assertIs(serialized, None)
 def test_most_recent_version_saved(self):
   """Checks the version data written for, and revived from, CustomTestClass.

   The serialized proto must list [3] as its bad consumers, and the object
   obtained by deserializing it must be a CustomTestClass at version 4.
   """
   saved_proto = revived_types.serialize(CustomTestClass(None))
   recorded_bad_consumers = saved_proto.version.bad_consumers
   self.assertEqual([3], recorded_bad_consumers)

   # deserialize() hands back a pair; only the revived object matters here.
   revived, _ = revived_types.deserialize(saved_proto)
   self.assertIsInstance(revived, CustomTestClass)
   self.assertEqual(4, revived.version)
 def test_save_typecheck(self):
   """Checks that serializing a plain AutoTrackable yields None."""
   plain_trackable = tracking.AutoTrackable()
   serialized = revived_types.serialize(plain_trackable)
   self.assertIs(serialized, None)
Exemple #9
0
 def test_save_typecheck(self):
     """Checks that serializing a plain Checkpointable yields None."""
     plain_checkpointable = tracking.Checkpointable()
     serialized = revived_types.serialize(plain_checkpointable)
     self.assertIs(serialized, None)