def _setup_function_captures(self, concrete_function_name, nodes):
  """Setup captures and variables in a restored function.

  Injects the restored bound inputs as captures of the named concrete
  function and attaches the restored variables to its FuncGraph.

  Args:
    concrete_function_name: Key into both `self._concrete_functions` and
      `self._proto.concrete_functions` identifying the function to wire up.
    nodes: Mapping from node id to the restored Python object for that node;
      looked up for each id in the function proto's `bound_inputs`.
  """
  # Record that this function's captures have been restored so the work is
  # not repeated.
  self._restored_concrete_functions.add(concrete_function_name)
  concrete_function = self._concrete_functions[concrete_function_name]
  proto = self._proto.concrete_functions[concrete_function_name]
  # Resolve each bound-input node id to a capturable tensor (handle, asset
  # path, or tensor — see _get_tensor_from_node).
  bound_inputs = [
      self._get_tensor_from_node(nodes[node_id])
      for node_id in proto.bound_inputs]
  # Only the bound inputs whose proto node is a variable belong on the
  # FuncGraph's `variables` list.
  bound_variables = [
      nodes[node_id]
      for node_id in proto.bound_inputs
      if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
  ]
  # TODO(b/205010575): This is only injecting the captured inputs into the
  # concrete function, note that we did not modify the FuncGraph
  # itself.
  captured_inputs_list = []
  concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
  if bound_inputs:
    # The last len(bound_inputs) inputs of the concrete function are the
    # internal capture placeholders, in the same order as `bound_inputs`.
    for bound_input, internal_capture in zip(
        bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
      if distribute_utils.is_distributed_variable(bound_input):
        concrete_function.graph.capture_distributed_variable(
            bound_input, internal_capture)
        captured_inputs_list.append(bound_input)
      elif distribute_utils.is_distributed_table(bound_input):
        # Distributed tables resolve their resource handle lazily at call
        # time, so the capture is deferred via a closure + spec.
        closure, spec = bound_input.resource_handle_call_time_value()
        concrete_function.graph.replace_capture_with_deferred_capture(
            bound_input._coordinator_instance.resource_handle,  # pylint: disable=protected-access
            closure,
            spec,
            default_value=bound_input._coordinator_instance.resource_handle,  # pylint: disable=protected-access
            placeholder=internal_capture)
        captured_inputs_list.append(
            concrete_function.graph.deferred_external_captures[-1])
      else:
        captured_inputs_list.append(bound_input)
        concrete_function.graph.replace_capture(bound_input,
                                                internal_capture)
        if internal_capture.dtype == dtypes.resource:
          # Propagate resource handle data (shape/dtype of the pointed-to
          # value) onto the internal placeholder when available.
          if resource_variable_ops.is_resource_variable(bound_input):
            try:
              handle = bound_input.handle
            except ValueError:
              # For mirrored variables we'll copy handle data for components
              # as they get captured.
              pass
            else:
              handle_data_util.copy_handle_data(handle, internal_capture)
          else:
            handle_data_util.copy_handle_data(bound_input, internal_capture)
        # Setting "captures" first means "capture" won't create a new
        # placeholder for this input.
        concrete_function.graph.capture(bound_input)
    concrete_function.set_external_captures(captured_inputs_list)
def _get_tensor_from_node(self, node):
  """Resolves a node id into a tensor to be captured for a function."""
  with ops.init_scope():
    # Distributed wrappers and plain TF tensors are captured as-is.
    if (distribute_utils.is_distributed_variable(node)
        or distribute_utils.is_distributed_table(node)):
      return node
    if resource_variable_ops.is_resource_variable(node):
      return node.handle
    if isinstance(node, tracking.Asset):
      return node.asset_path
    if tensor_util.is_tf_type(node):
      return node
    if isinstance(node, tracking.CapturableResource):
      # Note: this executes restored functions in the CapturableResource.
      return node.resource_handle
  raise ValueError(f"Cannot convert node {node} to tensor.")
def _get_tensor_from_node(self, node_id, fn_name):
  """Resolves a node id into a tensor to be captured for a function."""
  # NOTE(review): an earlier `_get_tensor_from_node` with a different
  # signature appears in this file; if both live on the same class, this
  # later definition shadows it — confirm which revision is intended.
  if self._node_filters is not None and self._nodes[node_id] is None:
    raise ValueError(
        f"Error when processing nodes_to_load. Function '{fn_name}' requires "
        "inputs/variables that are not loaded when nodes_to_load="
        f"{self._node_filters}.")

  obj = self._nodes[node_id]
  with ops.init_scope():
    # Distributed wrappers and plain TF tensors are captured as-is.
    if (distribute_utils.is_distributed_variable(obj)
        or distribute_utils.is_distributed_table(obj)):
      return obj
    if resource_variable_ops.is_resource_variable(obj):
      return obj.handle
    if isinstance(obj, tracking.Asset):
      return obj.asset_path
    if tensor_util.is_tf_type(obj):
      return obj
    if isinstance(obj, tracking.CapturableResource):
      # Note: this executes restored functions in the CapturableResource.
      return obj.resource_handle
  raise ValueError(f"Cannot convert node {obj} to tensor.")