def _map_resources(accessible_objects):
  """Creates graph-side stand-ins for every resource variable supplied.

  The stand-ins are resource handle ops built in the current default graph,
  whereas the objects in `accessible_objects` will be from an eager context.
  Putting these handle ops in the main GraphDef of a SavedModel is what allows
  the C++ loader API to interact with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain
      resources. Only objects accepted by
      `resource_variable_ops.is_resource_variable` get a replacement; all
      others are skipped.

  Returns:
    A tuple of (object_map, resource_map):
      object_map: A dictionary mapping each resource variable found in
        `accessible_objects` to the replacement variable created for it.
      resource_map: A dictionary mapping the handle of each original variable
        to the handle of its replacement.
  """
  # TODO(allenl, rohanj): Map generic resources rather than just variables.
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = {}
  resource_map = {}
  # Lazy filter: each object is tested right before it is (possibly) copied.
  resource_variables = (
      candidate for candidate in accessible_objects
      if resource_variable_ops.is_resource_variable(candidate))
  for original in resource_variables:
    replacement = resource_variable_ops.copy_to_graph_uninitialized(original)
    object_map[original] = replacement
    resource_map[original.handle] = replacement.handle
  return object_map, resource_map
def _map_resources(accessible_objects):
  """Builds uninitialized graph copies of the resource variables passed in.

  New resource handle ops are created in the current default graph, while
  `accessible_objects` will be from an eager context. Having these handle ops
  in the main GraphDef of a SavedModel allows the C++ loader API to interact
  with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain
      resources, to create replacements for. Objects that are not resource
      variables are ignored.

  Returns:
    A tuple of (object_map, resource_map):
      object_map: A dictionary from each resource variable in
        `accessible_objects` to the variable created to replace it.
      resource_map: A dictionary from each original variable's handle to the
        handle of its replacement.
  """
  # TODO(allenl, rohanj): Map generic resources rather than just variables.
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map, resource_map = {}, {}
  for source_obj in accessible_objects:
    if not resource_variable_ops.is_resource_variable(source_obj):
      continue
    graph_variable = resource_variable_ops.copy_to_graph_uninitialized(
        source_obj)
    object_map[source_obj] = graph_variable
    resource_map[source_obj.handle] = graph_variable.handle
  return object_map, resource_map
def map_resources(self):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph for the objects in
  `self.nodes`, which will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the
  C++ loader API to interact with variables.

  Tensors captured by `self.concrete_functions` that are not yet tracked in
  `self.captured_tensor_node_ids` (and whose dtype is not in
  `_UNCOPIABLE_DTYPES`) are copied into the graph as constants. Each copy is
  appended to `self.nodes` as a `_CapturedConstant` and registered in
  `self.node_ids` and `self.captured_tensor_node_ids`.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping from object in `self.nodes` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `self.nodes` (and from captured tensors) to newly created resource
        tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from `self.nodes`.

  Raises:
    ValueError: If a function captures a tensor whose value cannot be
      extracted as a constant.
  """
  # Only makes sense when adding to the export Graph
  assert not context.executing_eagerly()
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = object_identity.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})
  for node_id, obj in enumerate(self.nodes):
    if isinstance(obj, tracking.TrackableResource):
      new_resource = obj._create_resource()  # pylint: disable=protected-access
      resource_map[obj.resource_handle] = new_resource
      self.captured_tensor_node_ids[obj.resource_handle] = node_id
    elif resource_variable_ops.is_resource_variable(obj):
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
      self.captured_tensor_node_ids[obj.handle] = node_id
    elif isinstance(obj, tracking.TrackableAsset):
      _process_asset(obj, asset_info, resource_map)
      self.captured_tensor_node_ids[obj.asset_path] = node_id

  for concrete_function in self.concrete_functions:
    for capture in concrete_function.captured_inputs:
      if (tensor_util.is_tensor(capture)
          and capture.dtype not in _UNCOPIABLE_DTYPES
          and capture not in self.captured_tensor_node_ids):
        capture_constant_value = tensor_util.constant_value(capture)
        if capture_constant_value is None:
          # Without this check the None would be handed to
          # `constant_op.constant`, hiding which function and tensor caused
          # the failure.
          raise ValueError(
              ("Attempted to save a function {} which references a symbolic "
               "Tensor {} that is not a simple constant. This is not "
               "supported.").format(concrete_function.name, capture))
        copied_tensor = constant_op.constant(capture_constant_value)
        # The new node goes at the end of `self.nodes`, so its id is the
        # current length.
        node_id = len(self.nodes)
        node = _CapturedConstant(
            eager_tensor=capture, graph_tensor=copied_tensor)
        self.nodes.append(node)
        self.node_ids[capture] = node_id
        self.node_ids[node] = node_id
        self.captured_tensor_node_ids[capture] = node_id
        resource_map[capture] = copied_tensor

  return object_map, resource_map, asset_info
def map_resources(self):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph for the objects in
  `self.nodes`, which will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the
  C++ loader API to interact with variables.

  Tensors captured by `self.concrete_functions` that are not yet tracked in
  `self.captured_tensor_node_ids` (and whose dtype is not in
  `_UNCOPIABLE_DTYPES`) are copied into the graph as constants. Each copy is
  appended to `self.nodes` as a `_CapturedConstant` and registered in
  `self.node_ids` and `self.captured_tensor_node_ids`.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping from object in `self.nodes` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `self.nodes` (and from captured tensors) to newly created resource
        tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from `self.nodes`.

  Raises:
    ValueError: If a function captures a tensor whose value cannot be
      extracted as a constant.
  """
  # Only makes sense when adding to the export Graph
  assert not context.executing_eagerly()
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = object_identity.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})
  for node_id, obj in enumerate(self.nodes):
    if isinstance(obj, tracking.TrackableResource):
      new_resource = obj._create_resource()  # pylint: disable=protected-access
      resource_map[obj.resource_handle] = new_resource
      self.captured_tensor_node_ids[obj.resource_handle] = node_id
    elif resource_variable_ops.is_resource_variable(obj):
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
      self.captured_tensor_node_ids[obj.handle] = node_id
    elif isinstance(obj, tracking.TrackableAsset):
      _process_asset(obj, asset_info, resource_map)
      self.captured_tensor_node_ids[obj.asset_path] = node_id

  for concrete_function in self.concrete_functions:
    for capture in concrete_function.captured_inputs:
      if (tensor_util.is_tensor(capture)
          and capture.dtype not in _UNCOPIABLE_DTYPES
          and capture not in self.captured_tensor_node_ids):
        capture_constant_value = tensor_util.constant_value(capture)
        if capture_constant_value is None:
          # Fail with a message naming the function and tensor rather than
          # passing None on to `constant_op.constant`.
          raise ValueError(
              ("Attempted to save a function {} which references a symbolic "
               "Tensor {} that is not a simple constant. This is not "
               "supported.").format(concrete_function.name, capture))
        copied_tensor = constant_op.constant(capture_constant_value)
        # Appended below, so the node's id equals the current node count.
        node_id = len(self.nodes)
        node = _CapturedConstant(
            eager_tensor=capture, graph_tensor=copied_tensor)
        self.nodes.append(node)
        self.node_ids[capture] = node_id
        self.node_ids[node] = node_id
        self.captured_tensor_node_ids[capture] = node_id
        resource_map[capture] = copied_tensor

  return object_map, resource_map, asset_info
def testCopyToGraphUninitialized(self):
  """Copy keeps the variable name; running its initializer must fail."""
  source_var = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
  target_graph = ops.Graph()
  # Intentionally testing v1 behavior
  with target_graph.as_default():
    graph_copy = resource_variable_ops.copy_to_graph_uninitialized(source_var)
    self.assertEqual(source_var.name, graph_copy.name)
  with self.session(target_graph) as sess:
    # The test expects the copy's initializer to be unusable.
    with self.assertRaises(errors.InvalidArgumentError):
      sess.run(graph_copy.initializer)
def testCopyToGraphUninitialized(self):
  """Verifies name preservation and that the copy's initializer errors out."""
  eager_variable = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
  destination = ops.Graph()
  with destination.as_default():  # Intentionally testing v1 behavior
    uninitialized = resource_variable_ops.copy_to_graph_uninitialized(
        eager_variable)
    self.assertEqual(eager_variable.name, uninitialized.name)
  # Running the copied variable's initializer is expected to raise.
  with self.session(destination) as sess, self.assertRaises(
      errors.InvalidArgumentError):
    sess.run(uninitialized.initializer)
def _map_resources(accessible_objects):
  """Creates graph-side replacements for resources, variables and assets.

  Resource handle ops are created in the current default graph, whereas
  `accessible_objects` will be from an eager context. Adding these handle ops
  to the main GraphDef of a SavedModel allows the C++ loader API to interact
  with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain
      resources, to create replacements for. Objects that are not a
      `tracking.TrackableResource`, a resource variable or a
      `tracking.TrackableAsset` are skipped.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: An identity-keyed dictionary mapping each resource variable
        in `accessible_objects` to the replacement variable created for it.
      resource_map: A dictionary mapping resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from accessible_objects.
  """
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = util.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})
  for candidate in accessible_objects:
    if isinstance(candidate, tracking.TrackableResource):
      # The right-hand side runs first, so the resource is created before the
      # original handle is read, as before.
      resource_map[candidate.resource_handle] = candidate.create_resource()
      continue
    if resource_variable_ops.is_resource_variable(candidate):
      replacement = resource_variable_ops.copy_to_graph_uninitialized(
          candidate)
      object_map[candidate] = replacement
      resource_map[candidate.handle] = replacement.handle
      continue
    if isinstance(candidate, tracking.TrackableAsset):
      _process_asset(candidate, asset_info, resource_map)
  return object_map, resource_map, asset_info
def _map_resources(accessible_objects):
  """Maps eager resources, variables and assets to new graph counterparts.

  The counterparts are resource handle ops created in the current default
  graph, whereas `accessible_objects` will be from an eager context. Having
  them in the main GraphDef of a SavedModel allows the C++ loader API to
  interact with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain
      resources, to create replacements for.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: An identity-keyed dictionary from each resource variable in
        `accessible_objects` to its replacement.
      resource_map: A dictionary from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from accessible_objects.
  """
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = util.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(asset_defs=[],
                          asset_initializers_by_resource={},
                          asset_filename_map={},
                          asset_index={})
  for item in accessible_objects:
    if isinstance(item, tracking.TrackableResource):
      created = item.create_resource()
      resource_map[item.resource_handle] = created
    elif resource_variable_ops.is_resource_variable(item):
      graph_variable = resource_variable_ops.copy_to_graph_uninitialized(item)
      object_map[item] = graph_variable
      resource_map[item.handle] = graph_variable.handle
    elif isinstance(item, tracking.TrackableAsset):
      _process_asset(item, asset_info, resource_map)
  return object_map, resource_map, asset_info
def map_resources(self):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph for the objects in
  `self.nodes`, which will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the
  C++ loader API to interact with variables.

  Tensors captured by `self.concrete_functions` that are not yet tracked in
  `self.captured_tensor_node_ids` (and whose dtype is not in
  `_UNCOPIABLE_DTYPES`) are copied into the graph as constants. Each copy is
  appended to `self.nodes` as a `_CapturedConstant` and registered in
  `self.node_ids` and `self.captured_tensor_node_ids`.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping from object in `self.nodes` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `self.nodes` (and from captured tensors) to newly created resource
        tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from `self.nodes`.

  Raises:
    ValueError: If a concrete function's graph is not saveable, or if a
      function captures a tensor whose value cannot be extracted as a
      constant.
  """
  # Only makes sense when adding to the export Graph
  assert not context.executing_eagerly()
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = object_identity.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})

  for node_id, obj in enumerate(self.nodes):
    if isinstance(obj, tracking.CapturableResource):
      # pylint: disable=protected-access
      with ops.device(obj._resource_device):
        new_resource = obj._create_resource()
      # pylint: enable=protected-access
      resource_map[obj.resource_handle] = new_resource
      self.captured_tensor_node_ids[obj.resource_handle] = node_id
    elif (ds_values.is_distributed_variable(obj) or
          resource_variable_ops.is_resource_variable(obj)):
      # For a distributed variable only its `primary` component is copied.
      obj_to_copy = obj.primary if ds_values.is_distributed_variable(
          obj) else obj
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(
          obj_to_copy)
      if ds_values.is_distributed_variable(obj):
        # The distributed variable object itself is also recorded as a
        # captured "tensor", in addition to its handle below.
        self.captured_tensor_node_ids[obj] = node_id
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
      self.captured_tensor_node_ids[obj.handle] = node_id
    elif isinstance(obj, tracking.Asset):
      _process_asset(obj, asset_info, resource_map)
      self.captured_tensor_node_ids[obj.asset_path] = node_id

  for concrete_function in self.concrete_functions:
    if not concrete_function.graph.saveable:
      # Format only the header. The saving errors are appended verbatim so
      # that braces inside an error message are not treated as replacement
      # fields, which would raise a formatting error instead of this one.
      raise ValueError(
          "Unable to save function {name} for the following reason(s):\n"
          .format(name=concrete_function.name)
          + "\n".join(concrete_function.graph.saving_errors))
    for capture in concrete_function.captured_inputs:
      if (tensor_util.is_tensor(capture)
          and capture.dtype not in _UNCOPIABLE_DTYPES
          and capture not in self.captured_tensor_node_ids):
        capture_constant_value = tensor_util.constant_value(capture)
        if capture_constant_value is None:
          raise ValueError(
              ("Attempted to save a function {} which references a symbolic "
               "Tensor {} that is not a simple constant. This is not "
               "supported.").format(concrete_function.name, capture))
        copied_tensor = constant_op.constant(capture_constant_value)
        # The new node goes at the end of `self.nodes`, so its id is the
        # current length.
        node_id = len(self.nodes)
        node = _CapturedConstant(
            eager_tensor=capture, graph_tensor=copied_tensor)
        self.nodes.append(node)
        self.node_ids[capture] = node_id
        self.node_ids[node] = node_id
        self.captured_tensor_node_ids[capture] = node_id
        resource_map[capture] = copied_tensor

  return object_map, resource_map, asset_info
def map_resources(self):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph for the objects in
  `self.nodes`, which will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the
  C++ loader API to interact with variables.

  Besides building the maps, this method mutates the instance:
    * `self.captured_tensor_node_ids` gains an entry for every mapped handle,
      asset path, distributed variable and copied capture.
    * Captured tensors with an extractable constant value are appended to
      `self.nodes` as `_CapturedConstant` nodes and added to `self.node_ids`.
    * Functions capturing a tensor that has a `_cached_variable` attribute are
      wrapped via `function_serialization.wrap_cached_variables`; the wrapper
      is stored in `self.wrapped_functions` and its name in
      `self.function_name_map`.
    * `self.concrete_functions` is rebuilt without the functions that capture
      a non-constant tensor, and with wrapped functions substituted.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping from object in `self.nodes` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `self.nodes` (and from captured tensors) to newly created resource
        tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from `self.nodes`.

  Raises:
    ValueError: If a concrete function's graph is not saveable.
  """
  # Only makes sense when adding to the export Graph
  assert not context.executing_eagerly()
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = object_identity.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(asset_defs=[],
                          asset_initializers_by_resource={},
                          asset_filename_map={},
                          asset_index={})

  for node_id, obj in enumerate(self.nodes):
    if isinstance(obj, tracking.CapturableResource):
      # A shallow copy receives the new handle; the original object keeps
      # its own resource handle.
      new_obj = object_map[obj] = copy.copy(obj)
      # pylint: disable=protected-access
      with ops.device(obj._resource_device):
        new_resource = new_obj._create_resource()
      new_obj._resource_handle = new_resource
      # pylint: enable=protected-access
      resource_map[obj.resource_handle] = new_resource
      self.captured_tensor_node_ids[obj.resource_handle] = node_id
    elif (ds_values.is_distributed_variable(obj) or
          resource_variable_ops.is_resource_variable(obj)):
      # For a distributed variable only its `primary` component is copied.
      obj_to_copy = obj.primary if ds_values.is_distributed_variable(
          obj) else obj
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(
          obj_to_copy)
      if ds_values.is_distributed_variable(obj):
        self.captured_tensor_node_ids[obj] = node_id
        # Every component in `obj.values` maps to the single copied variable.
        for v in obj.values:
          object_map[v] = new_variable
          resource_map[v.handle] = new_variable.handle
          self.captured_tensor_node_ids[v.handle] = node_id
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
      self.captured_tensor_node_ids[obj.handle] = node_id
    elif isinstance(obj, tracking.Asset):
      _process_asset(obj, asset_info, resource_map)
      self.captured_tensor_node_ids[obj.asset_path] = node_id

  # Note: some concrete functions can have been realized when tracing other
  # functions, and might closure-capture tensors from their parent functions.
  # This is normal, but it means those concrete functions can't be serialized
  # as their own independent endpoints, so we filter them out here.
  bad_functions = []
  for concrete_function in self.concrete_functions:
    if not concrete_function.graph.saveable:
      # Format only the header. The saving errors are appended verbatim so
      # that braces inside an error message are not treated as replacement
      # fields, which would raise a formatting error instead of this one.
      raise ValueError(
          "Unable to save function {name} for the following reason(s):\n"
          .format(name=concrete_function.name)
          + "\n".join(concrete_function.graph.saving_errors))
    for capture in concrete_function.captured_inputs:
      if (tensor_util.is_tensor(capture)
          and capture.dtype not in _UNCOPIABLE_DTYPES
          and capture not in self.captured_tensor_node_ids):
        if hasattr(capture, "_cached_variable"):
          # Wrap the function once; later cached-variable captures of the
          # same function reuse the existing wrapper.
          if concrete_function not in self.wrapped_functions:
            wrapped = self.wrapped_functions[concrete_function] = (
                function_serialization.wrap_cached_variables(
                    concrete_function))
            self.function_name_map[compat.as_text(
                concrete_function.name)] = compat.as_text(wrapped.name)
          continue
        capture_constant_value = tensor_util.constant_value(capture)
        if capture_constant_value is None:
          # Not a simple constant: mark the function for removal instead of
          # failing, and keep examining its remaining captures.
          bad_functions.append(concrete_function)
          continue
        copied_tensor = constant_op.constant(capture_constant_value)
        # The new node goes at the end of `self.nodes`, so its id is the
        # current length.
        node_id = len(self.nodes)
        node = _CapturedConstant(eager_tensor=capture,
                                 graph_tensor=copied_tensor)
        self.nodes.append(node)
        self.node_ids[capture] = node_id
        self.node_ids[node] = node_id
        self.captured_tensor_node_ids[capture] = node_id
        resource_map[capture] = copied_tensor

  # Drop unserializable functions and substitute wrappers where they exist.
  self.concrete_functions = [
      self.wrapped_functions.get(x, x)
      for x in self.concrete_functions
      if x not in bad_functions
  ]
  return object_map, resource_map, asset_info