def _write_object_proto(obj, proto, asset_file_def_index):
  """Populates a SavedObject proto with the serialized form of `obj`.

  The first matching kind wins, checked in this order: trackable asset,
  resource variable, `def_function.Function`, `defun.ConcreteFunction`,
  captured constant, trackable resource, and finally a user object.

  Args:
    obj: The object to serialize.
    proto: The SavedObject proto, filled in place.
    asset_file_def_index: Container indexed by `obj`; only consulted when
      `obj` is a `tracking.TrackableAsset`.
  """
  if isinstance(obj, tracking.TrackableAsset):
    asset_proto = proto.asset
    asset_proto.SetInParent()
    asset_proto.asset_file_def_index = asset_file_def_index[obj]
    return

  if resource_variable_ops.is_resource_variable(obj):
    variable_proto = proto.variable
    variable_proto.SetInParent()
    variable_proto.trainable = obj.trainable
    variable_proto.dtype = obj.dtype.as_datatype_enum
    variable_proto.shape.CopyFrom(obj.shape.as_proto())
    return

  if isinstance(obj, def_function.Function):
    serialized_function = function_serialization.serialize_function(obj)
    proto.function.CopyFrom(serialized_function)
    return

  if isinstance(obj, defun.ConcreteFunction):
    serialized_concrete = (
        function_serialization.serialize_bare_concrete_function(obj))
    proto.bare_concrete_function.CopyFrom(serialized_concrete)
    return

  if isinstance(obj, _CapturedConstant):
    # Only the name of the op producing the constant tensor is recorded.
    proto.constant.operation = obj.graph_tensor.op.name
    return

  if isinstance(obj, tracking.TrackableResource):
    # No fields to fill; just mark the `resource` field as present.
    proto.resource.SetInParent()
    return

  user_object_proto = revived_types.serialize(obj)
  if user_object_proto is None:
    # No registered type matched `obj`; record it as a generic user object.
    generic_version = versions_pb2.VersionDef(
        producer=1, min_consumer=1, bad_consumers=[])
    user_object_proto = saved_object_graph_pb2.SavedUserObject(
        identifier="_generic_user_object", version=generic_version)
  proto.user_object.CopyFrom(user_object_proto)
def _write_object_proto(obj, proto, asset_file_def_index, node_ids):
  """Populates a SavedObject proto with the serialized form of `obj`.

  The first matching kind wins, checked in this order: trackable asset,
  resource variable, `def_function.Function`, `defun.ConcreteFunction`,
  and finally a user object.

  Args:
    obj: The object to serialize.
    proto: The SavedObject proto, filled in place.
    asset_file_def_index: Container indexed by `obj`; only consulted when
      `obj` is a `tracking.TrackableAsset`.
    node_ids: Forwarded unchanged to the `function_serialization` helpers
      when `obj` is a function or concrete function.
  """
  if isinstance(obj, tracking.TrackableAsset):
    asset_proto = proto.asset
    asset_proto.SetInParent()
    asset_proto.asset_file_def_index = asset_file_def_index[obj]
    return

  if resource_variable_ops.is_resource_variable(obj):
    variable_proto = proto.variable
    variable_proto.SetInParent()
    variable_proto.trainable = obj.trainable
    variable_proto.dtype = obj.dtype.as_datatype_enum
    variable_proto.shape.CopyFrom(obj.shape.as_proto())
    return

  if isinstance(obj, def_function.Function):
    serialized_function = function_serialization.serialize_function(
        obj, node_ids)
    proto.function.CopyFrom(serialized_function)
    return

  if isinstance(obj, defun.ConcreteFunction):
    serialized_concrete = function_serialization.serialize_concrete_function(
        obj, node_ids)
    proto.concrete_function.CopyFrom(serialized_concrete)
    return

  user_object_proto = revived_types.serialize(obj)
  if user_object_proto is None:
    # No registered type matched `obj`; record it as a generic user object.
    generic_version = versions_pb2.VersionDef(
        producer=1, min_consumer=1, bad_consumers=[])
    user_object_proto = saved_object_graph_pb2.SavedUserObject(
        identifier="_generic_user_object", version=generic_version)
  proto.user_object.CopyFrom(user_object_proto)
def _write_object_proto(obj, proto, asset_file_def_index):
  """Saves an object into SavedObject proto.

  The first matching kind wins, checked in this order: trackable asset,
  resource variable, `def_function.Function`, `defun.ConcreteFunction`,
  captured constant, capturable resource, and finally a user object.

  Args:
    obj: The object to serialize.
    proto: The SavedObject proto, filled in place.
    asset_file_def_index: Container indexed by `obj`; only consulted when
      `obj` is a `tracking.TrackableAsset`.

  Raises:
    ValueError: If `obj` is a resource variable whose name does not end
      with ":0".
  """
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    if not obj.name.endswith(":0"):
      # Interpolate the variable name so the error identifies the culprit
      # (the placeholder was previously left unformatted).
      raise ValueError("Cowardly refusing to save variable %s because of"
                       " unexpected suffix which won't be restored."
                       % (obj.name,))
    # NOTE(review): `_op_name` presumably drops the ":0" output suffix
    # checked above -- confirm against meta_graph.
    proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.synchronization = obj.synchronization.value
    proto.variable.aggregation = obj.aggregation.value
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(function_serialization.serialize_function(obj))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(obj))
  elif isinstance(obj, _CapturedConstant):
    # Only the name of the op producing the constant tensor is recorded.
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object",
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(registered_type_proto)
def _write_object_proto(obj, proto, asset_file_def_index):
  """Saves an object into SavedObject proto.

  The first matching kind wins, checked in this order: trackable asset,
  resource variable, `def_function.Function`, `defun.ConcreteFunction`,
  captured constant, capturable resource, and finally a user object.

  Args:
    obj: The object to serialize.
    proto: The SavedObject proto, filled in place.
    asset_file_def_index: Container indexed by `obj`; only consulted when
      `obj` is a `tracking.TrackableAsset`.

  Raises:
    ValueError: If `obj` is a resource variable whose name does not end
      with ":0".
  """
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    if not obj.name.endswith(":0"):
      # Interpolate the variable name so the error identifies the culprit
      # (the placeholder was previously left unformatted).
      raise ValueError("Cowardly refusing to save variable %s because of"
                       " unexpected suffix which won't be restored."
                       % (obj.name,))
    # NOTE(review): `_op_name` presumably drops the ":0" output suffix
    # checked above -- confirm against meta_graph.
    proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.synchronization = obj.synchronization.value
    proto.variable.aggregation = obj.aggregation.value
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(obj))
  elif isinstance(obj, _CapturedConstant):
    # Only the name of the op producing the constant tensor is recorded.
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object",
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(registered_type_proto)
def test_most_recent_version_saved(self):
  """Checks the serialized bad consumers and the version revived from them.

  Serializing a `CustomTestClass` must record `[3]` as bad consumers, and
  deserializing that proto must give back a `CustomTestClass` whose
  `version` is 4.
  """
  original = CustomTestClass(None)
  saved_proto = revived_types.serialize(original)
  self.assertEqual([3], saved_proto.version.bad_consumers)

  # `deserialize` returns a pair; only the first element is checked here.
  revived, _ignored = revived_types.deserialize(saved_proto)
  self.assertIsInstance(revived, CustomTestClass)
  self.assertEqual(4, revived.version)
def test_save_typecheck(self):
  """`revived_types.serialize` returns None for a plain `AutoTrackable`."""
  plain_object = autotrackable.AutoTrackable()
  serialized = revived_types.serialize(plain_object)
  self.assertIs(serialized, None)
def test_save_typecheck(self):
  """`revived_types.serialize` returns None for a plain `AutoTrackable`."""
  plain_object = tracking.AutoTrackable()
  serialized = revived_types.serialize(plain_object)
  self.assertIs(serialized, None)
def test_save_typecheck(self):
  """`revived_types.serialize` returns None for a plain `Checkpointable`."""
  plain_object = tracking.Checkpointable()
  serialized = revived_types.serialize(plain_object)
  self.assertIs(serialized, None)