def restore_captures(concrete_function, inputs):
  """Restore captures for the concrete function.

  Used at deserialization time. For functions that are being deserialized,
  saved model restores objects that tensors were captured from, but functions
  only know about their tensors -- object information is destroyed by
  tracing. This additional logic extracts the tensors which the function
  originally captured.

  Args:
    concrete_function: the concrete function for which to restore captures
    inputs: a list of tensors or other Python objects (such as variables)
      which contain tensors that were originally captured by the function
  """
  bound_inputs = [get_tensor_from_node(obj) for obj in inputs]
  # Only variable objects are registered as the function's variables; every
  # entry of `inputs` (variable or not) is still bound as a capture below.
  bound_variables = [
      obj for obj in inputs
      if isinstance(obj, (variables_lib.Variable,
                          resource_variable_ops.BaseResourceVariable))
  ]
  # TODO(b/205010575): This is only injecting the captured inputs into the
  # concrete function, note that we did not modify the FuncGraph
  # itself.
  captured_inputs_list = []
  concrete_function.set_variables(bound_variables)
  # The `if` guard is required: with no bound inputs, `inputs[-0:]` would be
  # the *whole* input list rather than an empty slice.
  if bound_inputs:
    # The capture placeholders are the trailing `len(bound_inputs)` inputs of
    # the concrete function, so pair them positionally with the bound inputs.
    for bound_input, internal_capture in zip(
        bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
      # Distributed inputs have special logic for capturing, so we call their
      # custom restoration methods
      if hasattr(bound_input, "__tf_experimental_restore_capture__"):
        captured_inputs_list.append(
            bound_input.__tf_experimental_restore_capture__(
                concrete_function, internal_capture))
      else:
        captured_inputs_list.append(bound_input)
        concrete_function.graph.replace_capture(bound_input, internal_capture)
        if internal_capture.dtype == dtypes.resource:
          if resource_variable_ops.is_resource_variable(bound_input):
            try:
              handle = bound_input.handle
            except ValueError:
              # For mirrored variables we'll copy handle data for components
              # as they get captured.
              pass
            else:
              handle_data_util.copy_handle_data(handle, internal_capture)
          else:
            # TODO(b/213451747): Remove need to call copy_handle_data
            handle_data_util.copy_handle_data(bound_input, internal_capture)
        # Setting "captures" first means "capture" won't create a new
        # placeholder for this input.
        concrete_function.graph.capture(bound_input)

  # A `None` entry can only come from a custom restore hook above; the warning
  # text attributes this case to ShardedVariables.
  if any(inp is None for inp in captured_inputs_list):
    warnings.warn("Trying to load ShardedVariables using tf.saved_model.load. "
                  "This won't work if using a tf.distribute.Strategy, and may "
                  "use excess memory if not using a Strategy. Ignore this "
                  "warning if using tf.keras.models.load_model.")
  concrete_function.set_external_captures(captured_inputs_list)
def _setup_function_captures(self, concrete_function_name, nodes):
  """Rebinds the saved captures and variables of one restored function.

  Records `concrete_function_name` as restored, resolves each node id listed
  in the function proto's `bound_inputs` to a tensor, and substitutes those
  tensors for the function's trailing input placeholders. The resulting
  external captures are then installed on the concrete function.

  Args:
    concrete_function_name: key of the function in both
      `self._concrete_functions` and `self._proto.concrete_functions`.
    nodes: restored objects, indexed by the node ids found in the function
      proto's `bound_inputs`.
  """
  self._restored_concrete_functions.add(concrete_function_name)
  fn = self._concrete_functions[concrete_function_name]
  fn_proto = self._proto.concrete_functions[concrete_function_name]
  node_ids = fn_proto.bound_inputs

  resolved = [self._get_tensor_from_node(nodes[i]) for i in node_ids]
  # Only nodes whose saved kind is "variable" become the graph's variables.
  tracked_variables = [
      nodes[i] for i in node_ids
      if self._proto.nodes[i].WhichOneof("kind") == "variable"
  ]
  # TODO(b/205010575): This is only injecting the captured inputs into the
  # concrete function, note that we did not modify the FuncGraph
  # itself.
  external_captures = []
  fn._func_graph.variables = tracked_variables  # pylint: disable=protected-access

  if resolved:
    placeholders = fn.inputs[-len(resolved):]
    for outer, placeholder in zip(resolved, placeholders):
      if distribute_utils.is_distributed_variable(outer):
        fn.graph.capture_distributed_variable(outer, placeholder)
        external_captures.append(outer)
        continue

      if distribute_utils.is_distributed_table(outer):
        closure, spec = outer.resource_handle_call_time_value()
        fn.graph.replace_capture_with_deferred_capture(
            outer._coordinator_instance.resource_handle,  # pylint: disable=protected-access
            closure,
            spec,
            default_value=outer._coordinator_instance.resource_handle,  # pylint: disable=protected-access
            placeholder=placeholder)
        # The deferred capture just registered is the last one in the list.
        external_captures.append(fn.graph.deferred_external_captures[-1])
        continue

      external_captures.append(outer)
      fn.graph.replace_capture(outer, placeholder)
      if placeholder.dtype == dtypes.resource:
        if not resource_variable_ops.is_resource_variable(outer):
          handle_data_util.copy_handle_data(outer, placeholder)
        else:
          try:
            handle = outer.handle
          except ValueError:
            # Mirrored variables get their components' handle data copied
            # later, when those components are themselves captured.
            pass
          else:
            handle_data_util.copy_handle_data(handle, placeholder)
      # replace_capture has already recorded this capture, so `capture` below
      # reuses it instead of creating a fresh placeholder.
      fn.graph.capture(outer)

  fn.set_external_captures(external_captures)
def _setup_functions_captures(self):
  """Rebinds captures and variables for every restored concrete function.

  Functions from `self._proto.concrete_functions` are visited in sorted
  name order. For each one:
  - the saved bound inputs are resolved to tensors;
  - nodes whose proto kind is "variable" are attached as the function graph's
    variables;
  - each resolved tensor is substituted for the matching trailing input
    placeholder.
  """
  saved_functions = sorted(self._proto.concrete_functions.items())
  for fn_name, fn_proto in saved_functions:
    fn = self._concrete_functions[fn_name]
    node_ids = fn_proto.bound_inputs
    resolved = [self._get_tensor_from_node(i, fn_name) for i in node_ids]
    tracked_variables = [
        self._nodes[i] for i in node_ids
        if self._proto.nodes[i].WhichOneof("kind") == "variable"
    ]
    # TODO(andresp): This is only injecting the captured inputs into the
    # concrete function, note that we did not modify the FuncGraph
    # itself.
    fn._captured_inputs = resolved  # pylint: disable=protected-access
    fn._func_graph.variables = tracked_variables  # pylint: disable=protected-access

    if not resolved:
      continue

    placeholders = fn.inputs[-len(resolved):]
    for outer, placeholder in zip(resolved, placeholders):
      if distribute_utils.is_distributed_variable(outer):
        fn.graph.capture_distributed_variable(outer, placeholder)
        continue

      fn.graph.replace_capture(outer, placeholder)
      if placeholder.dtype == dtypes.resource:
        if not resource_variable_ops.is_resource_variable(outer):
          handle_data_util.copy_handle_data(outer, placeholder)
        else:
          try:
            handle = outer.handle
          except ValueError:
            # Mirrored variables get their components' handle data copied
            # later, when those components are themselves captured.
            pass
          else:
            handle_data_util.copy_handle_data(handle, placeholder)
      # replace_capture has already recorded this capture, so `capture` below
      # reuses it instead of creating a fresh placeholder.
      fn.graph.capture(outer)
def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Writes `item` into the TensorList `input_handle` at position `index`.

  Args:
    input_handle: handle of the TensorList to update.
    index: position at which `item` is stored.
    item: the value to store.
    resize_if_index_out_of_bounds: when True, the list is first resized to
      `index + 1` elements if `index` is not smaller than its current length.
    name: optional name for the SetItem op.

  Returns:
    The handle produced by the SetItem op, with handle data copied from the
    (possibly resized) input list.
  """
  if resize_if_index_out_of_bounds:
    current_length = gen_list_ops.tensor_list_length(input_handle)

    def _grow_list():
      return gen_list_ops.tensor_list_resize(input_handle, index + 1)

    def _keep_list():
      return input_handle

    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= current_length, _grow_list, _keep_list)

  updated_handle = gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)
  handle_data_util.copy_handle_data(input_handle, updated_handle)
  return updated_handle
def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Scatters the rows of `tensor` into a TensorList at `indices`.

  Without `input_handle`, a new list is built. Otherwise the rows are
  scattered into the existing list and its handle data is propagated.

  Args:
    tensor: value whose rows are scattered; converted with
      `ops.convert_to_tensor`.
    indices: positions in the list that receive the rows.
    element_shape: element shape for a newly created list; ignored when
      `input_handle` is given.
    input_handle: optional existing TensorList to scatter into.
    name: optional op name.

  Returns:
    The handle of the created or updated TensorList.
  """
  tensor = ops.convert_to_tensor(tensor)

  if input_handle is None:
    # NOTE(review): num_elements=-1 presumably lets the op size the list from
    # `indices` -- confirm against the TensorListScatterV2 op definition.
    created_handle = gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
    _set_handle_data(created_handle, element_shape, tensor.dtype)
    return created_handle

  updated_handle = gen_list_ops.tensor_list_scatter_into_existing_list(
      input_handle=input_handle, tensor=tensor, indices=indices, name=name)
  handle_data_util.copy_handle_data(input_handle, updated_handle)
  return updated_handle
def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode.

  Calls `f(*args)`, which must return `(result, grad_fn)`. All flattened
  outputs, inputs and used variables are routed through a single `IdentityN`
  op whose gradient is overridden to call `grad_fn`.

  Args:
    f: function returning a `(result, grad_fn)` pair.
    args: positional arguments for `f`; converted to tensors first.
    kwargs: keyword arguments; must be empty in graph mode.

  Returns:
    `result` with its flattened components replaced by the `IdentityN`
    outputs, packed back into the original structure.

  Raises:
    ValueError: if `kwargs` is non-empty, or if the tensor outputs of `f`
      come from more than one graph.
    TypeError: if `f` creates a non-resource variable, or if variables are
      used but `grad_fn` does not accept a `variables` argument.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = generate_name()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no
  # non-resource Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()

  def _scope_variable_refs():
    """Returns refs of all global + local variables of the current scope."""
    return {
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    }

  before_vars = _scope_variable_refs()
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)

  args = nest.flatten(args)
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)

  after_vars = _scope_variable_refs()
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset(
      v.ref() for v in variable_watcher.watched_variables())

  # Not all results may be tensors. However, we want to ensure all tensor
  # outputs are from the same graph and get a list of captured inputs for
  # variable search.
  graphs = {getattr(o, "graph", None) for o in flat_result}
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError(
          "All custom_gradient outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_input_tensors = [i for i in args if i.graph == output_graph]
  else:
    filtered_input_tensors = args

  variables_in_subgraph = frozenset(
      v.ref() for v in _get_dependent_variables(
          input_ops=filtered_input_tensors, output_ops=flat_result))
  # Sorted by name so the order passed to grad_fn is deterministic.
  variables = sorted(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)],
      key=lambda v: v.name)

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  # A `**kwargs` parameter (varkw) can also receive `variables=`.
  variables_in_signature = ("variables" in grad_argspec.args or
                            "variables" in grad_argspec.kwonlyargs or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warning(
        "@custom_gradient grad_fn has 'variables' in signature, but "
        "no ResourceVariables were used on the forward pass.")

  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    # Only the gradients of the `flat_result` slots are handed to grad_fn.
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
      # grad_fn may return a tuple; normalize so the list concatenation
      # below does not fail with `list + tuple`.
      variable_grads = list(variable_grads)
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    handle_data_util.copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])