def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
  """Captures a Tensor while building a graph mode function.

  Arguments:
    value: A Tensor object.
    dtype: The datatype of the value produced by the node in the graph.
    name: Name of the node in the graph.
    as_ref: Ignored (required by register_tensor_conversion_function).

  Returns:
    Returns a constant (the current value of the tensor) if capturing is not
    enabled. A placeholder which will have the value of the tensor at runtime
    otherwise.
  """
  if context.in_eager_mode():
    return value
  _ = as_ref
  tensor_map = _scoped_captures.tensors
  if tensor_map is None:
    # Capturing is not enabled.
    return constant_op.constant(value.numpy())
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes.resource:
      captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value], [],
                        lambda x: x)
  return captured_value

def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg."""
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      handle_data = value._handle_data  # pylint: disable=protected-access
      captured_value._handle_data = handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # Ensure that shapes and dtypes are propagated.
        shapes, types = zip(*[(pair.shape, pair.dtype)
                              for pair in handle_data.shape_and_type])
        ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
        shapes = [[d.size for d in s.dim] if not s.unknown_rank else None
                  for s in shapes]
        with errors.raise_exception_on_not_ok_status() as status:
          pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
              captured_value._op._graph._c_graph,  # pylint: disable=protected-access
              captured_value._as_tf_output(),  # pylint: disable=protected-access
              shapes,
              ranks,
              types,
              status)
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value

def _capture_helper(self, tensor, name):
  captured_tensor = self.captures.get(tensor, None)
  if captured_tensor is None:
    captured_tensor = _create_substitute_placeholder(tensor, name=name,
                                                     dtype=tensor.dtype)
    self.captures[tensor] = captured_tensor
    self.inputs.append(captured_tensor)
  tape.record_operation("captured_value", [captured_tensor], [tensor],
                        lambda x: [x])
  return captured_tensor

def _record_gradient(op_name, inputs, attrs, results, ctx, name):
  """Records gradients for a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    inputs: A flat list of Tensor object inputs to the operation.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    results: The results of the operation (as a flat list).
    ctx: The value of context.context().
    name: Customized name for the operation.

  Returns:
    A list of maybe-wrapped results. Either Tensors or TensorNodes.

  Raises:
    An exception on error.
  """
  if not tape.could_possibly_record():
    return

  if op_name in _ops_which_dont_need_outputs:
    op_outputs = None
  else:
    # TODO(apassos) this line creates a weak circular reference where the
    # backprop function keeps an output alive which in turn keeps the tape
    # entry alive which keeps the backprop function alive. Figure out how to
    # break this up without breaking second derivatives of ops like Exp whose
    # gradients depend only on the outputs.
    op_outputs = results

  if op_name in _ops_which_dont_need_inputs:
    op_inputs = None
  else:
    op_inputs = inputs

  num_inputs = len(inputs)

  def grad_fn(*orig_outputs):
    """Generated gradient function."""
    result = _magic_gradient_function(op_name, attrs, num_inputs,
                                      op_inputs, op_outputs, orig_outputs)
    if _tracing:
      print("Gradient for", (name if name else op_name), "inputs", op_inputs,
            "output_grads", orig_outputs, "gradients", result)
    return result

  inputs = [ops.internal_convert_to_tensor(x, ctx=ctx) for x in inputs]
  tape.record_operation(op_name, results, inputs, [], grad_fn)
  if _tracing:
    print("Computed op", (name if name else op_name), "inputs", inputs,
          "outputs", results)

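
# For context, a minimal usage sketch (not taken from the source above) of the
# public flow that ends up in a helper like _record_gradient: executing an op
# while a tf.GradientTape is active records a gradient function on the tape,
# which tape.gradient later replays. Values and names here are illustrative.
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)            # constants are not watched automatically
  y = tf.exp(x)         # eager execution of Exp records its grad_fn on the tape
dy_dx = g.gradient(y, x)  # replays the recorded grad_fn: d/dx exp(x) = exp(x)
print(float(dy_dx))       # ~20.0855
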
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg."""
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes.resource:
      captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value

def decorated(*args, **kwargs):
  """Decorated function with custom gradient."""
  if context.in_graph_mode():
    if kwargs:
      raise ValueError(
          "custom_gradient in graph mode doesn't support keyword arguments.")
    name = "CustomGradient-%s" % tf_ops.uid()
    args = [tf_ops.convert_to_tensor(x) for x in args]
    result, grad_fn = f(*args)
    flat_result = nest.flatten(result)
    all_tensors = flat_result + args

    @tf_ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
      gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
      # Need to return one value per input to the IdentityN, so pad the
      # gradients of the inputs of the custom_gradient function with the
      # gradients of the outputs as well.
      return ([None] * len(flat_result)) + gradients

    with tf_ops.get_default_graph().gradient_override_map(
        {"IdentityN": name}):
      all_tensors = array_ops.identity_n(all_tensors)
    return nest.pack_sequence_as(
        structure=result, flat_sequence=all_tensors[:len(flat_result)])

  input_tensors = [x for x in args if isinstance(x, tf_ops.Tensor)]
  with tape.stop_recording():
    result, grad_fn = f(*args, **kwargs)

  # TODO(apassos): naive uses of custom_gradient will not get the correct
  # second derivative this way if they capture any output tensors. Change the
  # signature of custom_gradient.
  def actual_grad_fn(*outputs):
    return grad_fn(*outputs)

  flat_result = nest.flatten(result)
  tape.record_operation(
      f.__name__,
      flat_result,
      input_tensors,
      [],
      actual_grad_fn)
  flat_result = list(flat_result)
  return result

def decorated(*args, **kwargs):
  """Decorated function with custom gradient."""
  if context.in_graph_mode():
    if kwargs:
      raise ValueError(
          "custom_gradient in graph mode doesn't support keyword arguments.")
    name = "CustomGradient-%s" % tf_ops.uid()
    args = [tf_ops.convert_to_tensor(x) for x in args]
    result, grad_fn = f(*args)
    flat_result = nest.flatten(result)
    all_tensors = flat_result + args

    @tf_ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
      gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
      # Need to return one value per input to the IdentityN, so pad the
      # gradients of the inputs of the custom_gradient function with the
      # gradients of the outputs as well.
      return ([None] * len(flat_result)) + gradients

    with tf_ops.get_default_graph().gradient_override_map(
        {"IdentityN": name}):
      all_tensors = array_ops.identity_n(all_tensors)
    return nest.pack_sequence_as(
        structure=result, flat_sequence=all_tensors[:len(flat_result)])

  input_tensors = [tf_ops.convert_to_tensor(x) for x in args]
  result, grad_fn = f(*args, **kwargs)
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  def actual_grad_fn(*outputs):
    return nest.flatten(grad_fn(*outputs))

  tape.record_operation(
      f.__name__,
      flat_result,
      input_tensors,
      actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)

def _backprop_call(self, args):
  """Calls the wrapped function and records the result on a tape."""
  all_args = args + self._extra_inputs
  signature = self._forward_fdef.definition.signature
  ctx = context.context()
  if ctx.in_graph_mode():
    g = ops.get_default_graph()
    g._add_function(self._forward_fdef)  # pylint: disable=protected-access

    def make_tensor(x):
      if isinstance(x, ops.Tensor):
        return x
      return ops.internal_convert_to_tensor(x, ctx=ctx)

    op = g.create_op(
        signature.name,
        [make_tensor(x) for x in all_args],
        [dtypes.DType(x.type) for x in signature.output_arg],
        op_def=signature,
        name="FunctionCall",
        compute_shapes=False)
    outputs = op.outputs
    outputs = [outputs] if isinstance(
        outputs, (ops.Tensor, type(None))) else list(outputs)
    for i, s in enumerate(self._output_shapes):
      outputs[i].set_shape(s)
  else:
    outputs = execute.execute(
        str(signature.name),
        num_outputs=len(signature.output_arg),
        inputs=all_args,
        attrs=None,
        ctx=ctx)
  real_outputs = outputs[:len(self._returns)]
  side_outputs = outputs[len(self._returns):]

  def backward_function(*args):
    return self._backward_function(*(list(args) + side_outputs))

  tape.record_operation(
      signature.name,
      real_outputs,
      (args + self._extra_inputs),
      backward_function)

  return self._build_call_outputs(real_outputs)

def _eager_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for eager mode."""
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [v for v in set(tape.watched_variables()) if v not in all_inputs]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      not grad_argspec.varkw):
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]
  input_tensors = [ops.convert_to_tensor(x)
                   for x in list(args) + list(variables)]
  arg_count = len(args)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      raise ValueError(
          "custom_gradient function expected to return", arg_count,
          "gradients but returned", len(flat_grads), "instead.")
    return nest.flatten(input_grads) + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)

def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg."""
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      if ops._USE_C_SHAPES:  # pylint: disable=protected-access
        if isinstance(value, ops.EagerTensor):
          handle_data = value._handle_data  # pylint: disable=protected-access
        else:
          handle_data = resource_variable_ops.get_resource_handle_data(value)
      else:
        handle_data = value._handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # pylint: disable=protected-access
        if ops._USE_C_SHAPES:
          pywrap_tensorflow.SetResourceHandleShapeAndType(
              captured_value.graph._c_graph, captured_value._as_tf_output(),
              handle_data.SerializeToString())
        else:
          captured_value._handle_data = handle_data
        # pylint: enable=protected-access
        # Ensure that shapes and dtypes are propagated.
        shapes, types = zip(*[(pair.shape, pair.dtype)
                              for pair in handle_data.shape_and_type])
        ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
        shapes = [[d.size for d in s.dim] if not s.unknown_rank else None
                  for s in shapes]
        pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
            captured_value._op._graph._c_graph,  # pylint: disable=protected-access
            captured_value._as_tf_output(),  # pylint: disable=protected-access
            shapes, ranks, types)
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value

def decorated(*args, **kwargs):
  """Decorated function with custom gradient."""
  input_tensors = [x for x in args if isinstance(x, tf_ops.Tensor)]
  with tape.stop_recording():
    result, grad_fn = f(*args, **kwargs)

  # TODO(apassos): naive uses of custom_gradient will not get the correct
  # second derivative this way if they capture any output tensors. Change the
  # signature of custom_gradient.
  def actual_grad_fn(*outputs):
    return grad_fn(*outputs)

  flat_result = nest.flatten(result)
  tape.record_operation(
      flat_result,
      input_tensors,
      [],
      actual_grad_fn)
  flat_result = list(flat_result)
  return result

def decorated(*args, **kwargs):
  """Decorated function with custom gradient."""
  input_tensors = [_watch_value_from_tape(x) for x in args
                   if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                   or ag_core.isnode(x)]
  result, grad_fn = f(*args, **kwargs)
  flat_result = nest.flatten(result)
  flat_result = [ag_core.getval(x) for x in flat_result]
  flat_result = tape.record_operation(
      flat_result,
      input_tensors,
      [],
      grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)

def _record_gradient(op_name, inputs, attrs, results, name):
  """Records gradients for a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    inputs: A flat list of Tensor object inputs to the operation.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    results: The results of the operation (as a flat list).
    name: Customized name for the operation.

  Returns:
    A list of maybe-wrapped results. Either Tensors or TensorNodes.

  Raises:
    An exception on error.
  """
  if not any(ag_core.isnode(x) for x in inputs):
    return results
  num_outputs = len(results)
  if num_outputs == 0:
    return results
  if attrs is not None:
    attrs = tuple(tuple(x) if isinstance(x, list) else x for x in attrs)

  # It is imperative we make a copy of results here as otherwise we create a
  # dependency cycle in the captured function and this can delay garbage
  # collecting of the tensors arbitrarily.
  results_size = len(results) if isinstance(results, (list, tuple)) else 1

  def grad_fn(*orig_outputs):
    """Generated gradient function."""
    tensors = inputs + list(orig_outputs)
    tensors = container_types.make_sequence(tape.EagerList, *tensors)
    result = _magic_gradient_function(op_name, attrs, len(inputs),
                                      num_outputs, *(tensors))
    if _tracing:
      print("Gradient for", (name if name else op_name), "inputs", inputs,
            "output_grads", orig_outputs[results_size:], "gradients", result)
    return result

  results = tape.record_operation(results, inputs, [], grad_fn)
  if _tracing:
    print("Computed op", (name if name else op_name), "inputs", inputs,
          "outputs", results)
  return results

def _backprop_call(self, args):
  """Calls the wrapped function and records the result on a tape."""
  all_args = args + self._extra_inputs
  signature = self._forward_fdef.definition.signature
  if context.in_graph_mode():
    g = ops.get_default_graph()
    g._add_function(self._forward_fdef)  # pylint: disable=protected-access
    unwrapped_args = [ag_core.getval(x) for x in all_args]
    op = g.create_op(
        signature.name,
        [ops.convert_to_tensor(x) for x in unwrapped_args],
        [dtypes.DType(x.type) for x in signature.output_arg],
        op_def=signature,
        name="FunctionCall",
        compute_shapes=False)
    outputs = op.outputs
    outputs = [outputs] if isinstance(
        outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
    for i, s in enumerate(self._output_shapes):
      outputs[i].set_shape(s)
  else:
    outputs = execute.execute(
        signature.name,
        num_outputs=len(signature.output_arg),
        inputs=all_args)
  real_outputs = outputs[:len(self._returns)]
  side_outputs = outputs[len(self._returns):]
  watched_extra_inputs = []
  for t in self._extra_inputs:
    tid = ops.tensor_id(t)
    for t in tape._tape_stack.stack:  # pylint: disable=protected-access
      w = t.value.tensors.get(tid, None)
      if w is not None:
        watched_extra_inputs.append(w)
        break
    else:  # Note: for-else here done on purpose
      watched_extra_inputs.append(t)

  def backward_function_wrapper(*outputs):
    outputs = outputs[len(real_outputs):]
    return self._backward_function(*outputs)

  real_outputs = tape.record_operation(
      real_outputs,
      (args + watched_extra_inputs),
      side_outputs,
      backward_function_wrapper)

  return self._build_call_outputs(self._returns, real_outputs)

def testSpecialForwardFunctionUsed(self):
  c = constant_op.constant(1.)
  d = constant_op.constant(2.)
  e = constant_op.constant(3.)
  with forwardprop.ForwardAccumulator(c, 10.) as acc:
    tape_lib.record_operation("ForwardIsSpecial", [d], [c],
                              None, lambda jvp: [-2. * jvp])
    self.assertAllClose(-20., acc.jvp(d))
    tape_lib.record_operation("ForwardIsSpecial2", [], [],
                              None, lambda: [])
    tape_lib.record_operation("ForwardIsSpecial3", [e], [d],
                              None, lambda x: [x])
    self.assertAllClose(-20., acc.jvp(e))

def _backprop_call(self, args):
  """Calls the wrapped function and records the result on a tape."""
  all_args = args + self._extra_inputs
  signature = self._forward_fdef.definition.signature
  if context.in_graph_mode():
    g = ops.get_default_graph()
    g._add_function(self._forward_fdef)  # pylint: disable=protected-access
    unwrapped_args = [ag_core.getval(x) for x in all_args]
    op = g.create_op(
        signature.name,
        [ops.convert_to_tensor(x) for x in unwrapped_args],
        [dtypes.DType(x.type) for x in signature.output_arg],
        op_def=signature,
        name="FunctionCall",
        compute_shapes=False)
    outputs = op.outputs
    outputs = [outputs] if isinstance(
        outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
    for i, s in enumerate(self._output_shapes):
      outputs[i].set_shape(s)
  else:
    outputs = execute.execute(
        signature.name,
        num_outputs=len(signature.output_arg),
        inputs=all_args)
  real_outputs = outputs[:len(self._returns)]
  side_outputs = outputs[len(self._returns):]
  watched_extra_inputs = []
  for t in self._extra_inputs:
    tid = ops.tensor_id(t)
    for t in tape._tape_stack.stack:  # pylint: disable=protected-access
      w = t.value.tensors.get(tid, None)
      if w is not None:
        watched_extra_inputs.append(w)
        break
    else:  # Note: for-else here done on purpose
      watched_extra_inputs.append(t)
  real_outputs = tape.record_operation(real_outputs,
                                       (args + watched_extra_inputs),
                                       side_outputs,
                                       self._backward_function)
  return self._build_call_outputs(self._returns, real_outputs)

def decorated(*args, **kwargs):
  """Decorated function with custom gradient."""
  input_tensors = [
      _watch_value_from_tape(x) for x in args
      if isinstance(x, (_tensor.Tensor, tf_ops.Tensor)) or ag_core.isnode(x)
  ]
  result, grad_fn = f(*args, **kwargs)
  result_size = len(result) if isinstance(result, (list, tuple)) else 1

  # TODO(apassos): naive uses of custom_gradient will not get the correct
  # second derivative this way if they capture any output tensors. Change the
  # signature of custom_gradient.
  def actual_grad_fn(*outputs):
    outputs = outputs[result_size:]
    return grad_fn(*outputs)

  flat_result = nest.flatten(result)
  flat_result = [ag_core.getval(x) for x in flat_result]
  flat_result = tape.record_operation(flat_result, input_tensors, [],
                                      actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)

def decorated(*args, **kwargs):
  """Decorated function with custom gradient."""
  input_tensors = [_watch_value_from_tape(x) for x in args
                   if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                   or ag_core.isnode(x)]
  result, grad_fn = f(*args, **kwargs)
  result_size = len(result) if isinstance(result, (list, tuple)) else 1

  # TODO(apassos): naive uses of custom_gradient will not get the correct
  # second derivative this way if they capture any output tensors. Change the
  # signature of custom_gradient.
  def actual_grad_fn(*outputs):
    outputs = outputs[result_size:]
    return grad_fn(*outputs)

  flat_result = nest.flatten(result)
  flat_result = [ag_core.getval(x) for x in flat_result]
  flat_result = tape.record_operation(
      flat_result, input_tensors, [], actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)

def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  default_graph = ops.get_default_graph()

  def convert_arg(x):
    x = ops.convert_to_tensor(x)
    # If graph building, be sure to capture all inputs
    if default_graph.building_function and x.graph != default_graph:
      x = default_graph.capture(x)
    return x

  args = nest.map_structure(convert_arg, args)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)

  args = nest.flatten(args)
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)

  after_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset(
      [v.ref() for v in variable_watcher.watched_variables()])
  variables_in_subgraph = frozenset([
      v.ref()
      for v in _get_dependent_variables(input_ops=args, output_ops=flat_result)
  ])
  variables = list(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            "variables" in grad_argspec.kwonlyargs or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warn(
        "@custom_gradient grad_fn has 'variables' in signature, but "
        "no ResourceVariables were used on the forward pass.")

  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access

  tape_lib.record_operation(f.__name__, all_tensors, original_tensors,
                            tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(structure=result,
                               flat_sequence=all_tensors[:flat_result_len])

def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not isinstance(v, resource_variable_ops.ResourceVariable):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = list(set(tape.watched_variables()) - set(args))
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * len(flat_result)) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access

  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])

def inner(*args, _watch_vars=None, num_checkpoints=0, **kwargs):
  r"""Performs a forward pass while storing only the checkpoint activations."""
  if _watch_vars is None:
    _watch_vars = []
  tensor_watches = [tf.convert_to_tensor(x) for x in _watch_vars]
  model, x = args
  # Dictionary to cache the desired activations during forward pass
  saved_tensors = {}
  # index -1 represents the inputs x
  idx_ckpt = np.array([-1])
  num_layers = len(model.layers)

  # Perform checkpointing. Naive scheme - just distribute checkpoints
  # uniformly across the layers.
  if num_checkpoints:
    if num_checkpoints >= num_layers:
      raise ValueError(
          "The number of checkpoints is {} and should be less than the number "
          "of layers in the model, which is {}.".format(
              num_checkpoints, num_layers))
    idx_start, idx_end = 0, num_layers - 1
    # Use offset to avoid checkpointing the start and end layers of the model
    offset = idx_end // num_checkpoints
    start, end = (idx_start + offset) // 2, (idx_end - offset + idx_end) // 2
    idx_tmp = np.linspace(start, end, num_checkpoints, dtype=np.uint32)
    idx_ckpt = np.append(idx_ckpt, idx_tmp).tolist()

  x = tf.convert_to_tensor(x)
  with tape_lib.stop_recording():
    # perform forward pass while caching checkpoint layer outputs
    result = x
    saved_tensors[-1] = result
    for idx_layer in range(num_layers):
      result = model.layers[idx_layer](result)
      if idx_layer in idx_ckpt:
        saved_tensors[idx_layer] = result

  flat_result = nest.flatten(result)
  flat_result = [tf.identity(x) for x in flat_result]
  output = nest.pack_sequence_as(result, flat_result)

  def grad(*grads_output):
    r"""Performs the backward pass while recomputing the forward pass
    activations for each layer.
    """
    grads = []
    for idx_forward in range(len(model.layers)):
      idx_back = len(model.layers) - idx_forward - 1
      back_layer = model.layers[idx_back]
      idx_last_ckpt = idx_ckpt[-1]
      if idx_back <= idx_last_ckpt:
        idx_ckpt.pop()
        idx_last_ckpt = idx_ckpt[-1]
      # Recompute activations from the nearest checkpoint up to this layer.
      prev_output = saved_tensors[idx_last_ckpt]
      for idx_layer in range(idx_last_ckpt + 1, idx_back):
        prev_output = model.layers[idx_layer](prev_output)
      with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(back_layer.trainable_variables)
        tape.watch(prev_output)
        recomputed_output = back_layer(prev_output)
        # identity necessary for grad propagation across 'dead' layers
        recomputed_output = [tf.identity(x) for x in recomputed_output]
        recomputed_output = tf.convert_to_tensor(recomputed_output)
      prev_output = nest.flatten(prev_output)
      sources = prev_output + back_layer.trainable_variables
      grads_intermediate = tape.gradient(
          recomputed_output, sources, output_gradients=grads_output)
      grads_output = grads_intermediate[:len(prev_output)]
      grads_vars = grads_intermediate[len(prev_output):]
      grads.extend(grads_vars[::-1])
      del tape
    return grads[::-1]

  tape_lib.record_operation(str(f), flat_result, tensor_watches, grad)
  return output

def inner(*args, _checkpoint=False, _watch_vars=None, _force_seed=False,
          **kwargs):
  if _force_seed:
    if isinstance(_force_seed, Iterator):
      seed = next(_force_seed)
    else:
      seed = random.randint(1, 1 << 31)
  if _checkpoint:
    if _watch_vars is None:
      _watch_vars = []
    watch_args = []
    flat_inputs = nest.flatten(args) + nest.flatten(list(kwargs.values()))
    flat_inputs = [x for x in flat_inputs if tf.is_tensor(x)]
    flat_inputs = [x for x in flat_inputs if x.dtype == tf.float32]
    unique_inputs = [
        x.deref() for x in set(x.experimental_ref() for x in flat_inputs)
    ]
    unique_vars = [
        v.deref() for v in set(v.experimental_ref() for v in _watch_vars)
        if not any(v is inp for inp in flat_inputs)
    ]
    watches = unique_inputs + unique_vars
    tensor_watches = [tf.convert_to_tensor(x) for x in watches]

    with tape.stop_recording():
      if _force_seed:
        tf.random.set_seed(seed)
      result = f(*args, **kwargs)

    flat_result = nest.flatten(result)
    # No idea what the point of this is but they do it in tf.custom_gradient
    # so I'm doing it too
    flat_result = [tf.identity(x) for x in flat_result]
    output = nest.pack_sequence_as(result, flat_result)

    def grad(*output_grads):
      with tf.GradientTape() as g:
        g.watch(watches)
        if _force_seed:
          tf.random.set_seed(seed)
        recomputed_output = f(*args, **kwargs)
        recomputed_output = [
            tf.identity(x) for x in nest.flatten(recomputed_output)
        ]
      grads = g.gradient(recomputed_output, watches,
                         output_gradients=output_grads)
      del g
      return grads

    tape.record_operation(str(f), flat_result, tensor_watches, grad)
    return output
  else:
    if _force_seed:
      tf.random.set_seed(seed)
    return f(*args, **kwargs)

def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)

  args = nest.flatten(args)
  after_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # It is possible for the caller to pass in an input that is from a different
  # graph. Even though this is not valid we filter these out if they are not
  # from the output graph to make it easier for some code to migrate to custom
  # gradients.
  inputs = nest.flatten(args)
  outputs = nest.flatten(result)
  graphs = {getattr(o, "graph", None) for o in outputs}
  # Not all results may be tensors. However, we want to ensure that all outputs
  # are from the same graph and use that to filter the inputs.
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError("All graph outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_inputs = []
    for i in inputs:
      if i.graph != output_graph:
        logging.warn("%s does not belong to output graph %s", i, output_graph)
      else:
        filtered_inputs.append(i)
    inputs = filtered_inputs

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset([
      v.ref() for v in variable_watcher.watched_variables()
  ]) - frozenset(v.ref() for v in inputs)
  variables_in_subgraph = frozenset([
      v.ref()
      for v in get_dependent_variables(input_ops=inputs, output_ops=outputs)
  ])
  variables = list(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                 "no ResourceVariables were used on the forward pass.")

  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)
  all_tensors = flat_result + inputs + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access

  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])

def capture_distributed_variable(self, variable, placeholder):
  """Add given distributed variable to captures with given placeholder."""
  self._captures[ops.tensor_id(variable)] = (variable, placeholder)
  tape.record_operation("captured_value", [placeholder], [variable],
                        lambda x: [x])

def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = list(set(tape.watched_variables()) - set(args))
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * len(flat_result)) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access

  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])
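
# For context, a minimal usage sketch (not part of the source above) of the
# public tf.custom_gradient API that the decorator variants above implement:
# the wrapped function returns (result, grad_fn), and the machinery records
# grad_fn on the tape in eager mode or via an IdentityN gradient override in
# graph mode. The log1pexp example below is illustrative.
import tensorflow as tf


@tf.custom_gradient
def log1pexp(x):
  """log(1 + exp(x)) with a hand-written, numerically stable gradient."""
  e = tf.exp(x)

  def grad(upstream):
    # d/dx log(1 + exp(x)) = 1 - 1 / (1 + exp(x)), scaled by the upstream grad.
    return upstream * (1 - 1 / (1 + e))

  return tf.math.log(1 + e), grad


x = tf.constant(100.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = log1pexp(x)
print(float(g.gradient(y, x)))  # 1.0; the naive gradient would overflow to NaN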