def f():
  """Builds a variable and verifies the optimizer's slots are distributed."""
  v = variables.Variable([1.0])
  self.assertTrue(ds_values.is_distributed_variable(v))
  # Slot variables are created in the first call to apply_gradients.
  optimizer.apply_gradients([(ops.convert_to_tensor([1.0]), v)])
  self.assertTrue(optimizer.get_slot_names())
  for slot_name in optimizer.get_slot_names():
    slot_var = optimizer.get_slot(v, slot_name)
    self.assertIsNotNone(slot_var)
    self.assertTrue(ds_values.is_distributed_variable(slot_var))
def _get_tensor_from_node(self, node_id):
  """Resolves a node id into a tensor to be captured for a function."""
  with ops.init_scope():
    node = self._nodes[node_id]
    # Each branch returns, so plain `if` guards are equivalent to the
    # original elif chain; order matters (a distributed variable is
    # captured whole, before the resource-variable handle check).
    if ds_values.is_distributed_variable(node):
      return node
    if resource_variable_ops.is_resource_variable(node):
      return node.handle
    if isinstance(node, tracking.Asset):
      return node.asset_path
    if tensor_util.is_tensor(node):
      return node
    if isinstance(node, tracking.CapturableResource):
      # Note: this executes restored functions in the CapturableResource.
      return node.resource_handle
    raise ValueError("Can't convert node %s to tensor" % (type(node)))
def _setup_functions_captures(self):
  """Setup captures and variables in restored functions.

  For every restored concrete function, resolves the proto's bound-input
  node ids into live tensors/variables and injects them as the function's
  captured inputs, wiring each one to the function graph's internal
  capture placeholders.
  """
  # Sort by name so the restore order is deterministic.
  concrete_functions = sorted(self._proto.concrete_functions.items())
  for name, proto in concrete_functions:
    concrete_function = self._concrete_functions[name]
    bound_inputs = [
        self._get_tensor_from_node(node_id)
        for node_id in proto.bound_inputs]
    # Only nodes recorded as variables in the proto become the graph's
    # `variables` list; other bound inputs (assets, constants, resources)
    # are captures only.
    bound_variables = [
        self._nodes[node_id]
        for node_id in proto.bound_inputs
        if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
    ]
    # TODO(andresp): This is only injecting the captured inputs into the
    # concrete function, note that we did not modify the FuncGraph
    # itself.
    concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
    concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
    if bound_inputs:
      # The last len(bound_inputs) function inputs are the capture
      # placeholders; pair each external capture with its placeholder.
      for bound_input, internal_capture in zip(
          bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
        if ds_values.is_distributed_variable(bound_input):
          concrete_function.graph.capture_distributed_variable(
              bound_input, internal_capture)
        else:
          # Register the pair directly in the graph's capture map.
          concrete_function.graph._captures[ops.tensor_id(bound_input)] = (  # pylint: disable=protected-access
              bound_input, internal_capture)
          if internal_capture.dtype == dtypes.resource:
            if resource_variable_ops.is_resource_variable(bound_input):
              try:
                handle = bound_input.handle
              except ValueError:
                # For mirrored variables we'll copy handle data for components
                # as they get captured.
                pass
              else:
                custom_gradient.copy_handle_data(handle, internal_capture)
            else:
              custom_gradient.copy_handle_data(bound_input, internal_capture)
          # Setting "captures" first means "capture" won't create a new
          # placeholder for this input.
          concrete_function.graph.capture(bound_input)
def get_cross_replica_handle(x):
  """Substitutes `_unused_handle()` when `x` is a distributed variable."""
  if ds_values.is_distributed_variable(x):
    return _unused_handle()
  return x
def get_in_replica_handle(x):
  """Returns `x.handle` for distributed variables, `x` unchanged otherwise."""
  if ds_values.is_distributed_variable(x):
    return x.handle
  return x
def map_resources(self):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the
  C++ loader API to interact with variables.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping from object in `accessible_objects` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from accessible_objects.
  """
  # Only makes sense when adding to the export Graph
  assert not context.executing_eagerly()
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = object_identity.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})
  for node_id, obj in enumerate(self.nodes):
    if isinstance(obj, tracking.CapturableResource):
      # Recreate the resource on its recorded device in the export graph.
      # pylint: disable=protected-access
      with ops.device(obj._resource_device):
        new_resource = obj._create_resource()
      # pylint: enable=protected-access
      resource_map[obj.resource_handle] = new_resource
      self.captured_tensor_node_ids[obj.resource_handle] = node_id
    elif (ds_values.is_distributed_variable(obj) or
          resource_variable_ops.is_resource_variable(obj)):
      # For a distributed variable, copy only its primary component.
      obj_to_copy = obj.primary if ds_values.is_distributed_variable(
          obj) else obj
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(
          obj_to_copy)
      if ds_values.is_distributed_variable(obj):
        self.captured_tensor_node_ids[obj] = node_id
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
      self.captured_tensor_node_ids[obj.handle] = node_id
    elif isinstance(obj, tracking.Asset):
      _process_asset(obj, asset_info, resource_map)
      self.captured_tensor_node_ids[obj.asset_path] = node_id
  # Any remaining eager captures of the functions must be simple constants;
  # copy them into the export graph or fail the save.
  for concrete_function in self.concrete_functions:
    if not concrete_function.graph.saveable:
      raise ValueError(
          ("Unable to save function {name} for the following reason(s):\n" +
           "\n".join(concrete_function.graph.saving_errors))
          .format(name=concrete_function.name))
    for capture in concrete_function.captured_inputs:
      if (tensor_util.is_tensor(capture)
          and capture.dtype not in _UNCOPIABLE_DTYPES
          and capture not in self.captured_tensor_node_ids):
        capture_constant_value = tensor_util.constant_value(capture)
        if capture_constant_value is None:
          raise ValueError(
              ("Attempted to save a function {} which references a symbolic "
               "Tensor {} that is not a simple constant. This is not "
               "supported.").format(concrete_function.name, capture))
        copied_tensor = constant_op.constant(capture_constant_value)
        node_id = len(self.nodes)
        # Track the constant as its own node so later lookups resolve it.
        node = _CapturedConstant(
            eager_tensor=capture, graph_tensor=copied_tensor)
        self.nodes.append(node)
        self.node_ids[capture] = node_id
        self.node_ids[node] = node_id
        self.captured_tensor_node_ids[capture] = node_id
        resource_map[capture] = copied_tensor
  return object_map, resource_map, asset_info
def map_resources(self):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the
  C++ loader API to interact with variables.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping from object in `accessible_objects` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced
        from accessible_objects.
  """
  # Only makes sense when adding to the export Graph
  assert not context.executing_eagerly()
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = object_identity.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})
  for node_id, obj in enumerate(self.nodes):
    if isinstance(obj, tracking.CapturableResource):
      # Shallow-copy the resource object, then point the copy at a freshly
      # created resource in the export graph on the recorded device.
      new_obj = object_map[obj] = copy.copy(obj)
      # pylint: disable=protected-access
      with ops.device(obj._resource_device):
        new_resource = new_obj._create_resource()
      new_obj._resource_handle = new_resource
      # pylint: enable=protected-access
      resource_map[obj.resource_handle] = new_resource
      self.captured_tensor_node_ids[obj.resource_handle] = node_id
    elif (ds_values.is_distributed_variable(obj) or
          resource_variable_ops.is_resource_variable(obj)):
      # For a distributed variable, copy only its primary component.
      obj_to_copy = obj.primary if ds_values.is_distributed_variable(
          obj) else obj
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(
          obj_to_copy)
      if ds_values.is_distributed_variable(obj):
        self.captured_tensor_node_ids[obj] = node_id
        # Map every per-replica component onto the single copied variable.
        for v in obj.values:
          object_map[v] = new_variable
          resource_map[v.handle] = new_variable.handle
          self.captured_tensor_node_ids[v.handle] = node_id
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
      self.captured_tensor_node_ids[obj.handle] = node_id
    elif isinstance(obj, tracking.Asset):
      _process_asset(obj, asset_info, resource_map)
      self.captured_tensor_node_ids[obj.asset_path] = node_id

  # Note: some concrete functions can have been realized when tracing other
  # functions, and might closure-capture tensors from their parent functions.
  # This is normal, but it means those concrete functions can't be serialized
  # as their own independent endpoints, so we filter them out here.
  bad_functions = []
  for concrete_function in self.concrete_functions:
    if not concrete_function.graph.saveable:
      raise ValueError((
          "Unable to save function {name} for the following reason(s):\n" +
          "\n".join(concrete_function.graph.saving_errors)).format(
              name=concrete_function.name))
    for capture in concrete_function.captured_inputs:
      if (tensor_util.is_tensor(capture) and
          capture.dtype not in _UNCOPIABLE_DTYPES and
          capture not in self.captured_tensor_node_ids):
        if hasattr(capture, "_cached_variable"):
          # Captured cached-variable values get the function rewrapped once
          # and recorded under its new name.
          if concrete_function not in self.wrapped_functions:
            wrapped = self.wrapped_functions[concrete_function] = (
                function_serialization.wrap_cached_variables(
                    concrete_function))
            self.function_name_map[compat.as_text(
                concrete_function.name)] = (compat.as_text(wrapped.name))
          continue
        capture_constant_value = tensor_util.constant_value(capture)
        if capture_constant_value is None:
          # Non-constant closure capture: exclude this function rather than
          # failing the whole save (contrast with the hard error above).
          bad_functions.append(concrete_function)
          continue
        copied_tensor = constant_op.constant(capture_constant_value)
        node_id = len(self.nodes)
        # Track the constant as its own node so later lookups resolve it.
        node = _CapturedConstant(
            eager_tensor=capture, graph_tensor=copied_tensor)
        self.nodes.append(node)
        self.node_ids[capture] = node_id
        self.node_ids[node] = node_id
        self.captured_tensor_node_ids[capture] = node_id
        resource_map[capture] = copied_tensor

  # Swap in wrapped versions and drop the functions filtered out above.
  self.concrete_functions = [
      self.wrapped_functions.get(x, x)
      for x in self.concrete_functions
      if x not in bad_functions
  ]
  return object_map, resource_map, asset_info