def testMultipleScopes(self):
  var1 = variables.Variable(0.0)
  var2 = variables.Variable(1.0)
  with tape.VariableWatcher() as variable_watcher1:
    var1.assign_add(1.0)
    with tape.VariableWatcher() as variable_watcher2:
      var2.assign_add(2.0)

  # variable_watcher1 should see both vars and variable_watcher2 only sees
  # var2
  self.assertAllEqual(variable_watcher1.watched_variables(), (var1, var2))
  self.assertAllEqual(variable_watcher2.watched_variables(), (var2,))
def testCreateVariables(self):
  with tape.VariableWatcher() as variable_watcher:
    var1 = variables.Variable(0.0)
    var2 = variables.Variable(1.0)
    var1.assign_add(1.0)
    var2.assign_add(2.0)

  self.assertAllEqual(variable_watcher.watched_variables(), (var1, var2))
def testNonTrainableVariables(self):
  var1 = variables.Variable(0.0)
  var2 = variables.Variable(1.0, trainable=False)
  with tape.VariableWatcher() as variable_watcher:
    var1.assign_add(1.0)
    var2.assign_add(2.0)

  self.assertAllEqual(variable_watcher.watched_variables(), (var1,))
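
# A hedged sketch (not part of the original test suite) of the pattern the
# custom_gradient decorators below build on: run a forward computation inside
# a VariableWatcher scope, then collect the variables it touched. `forward_fn`
# is a hypothetical callable that reads or assigns tf.Variables internally.
def _example_collect_forward_variables(forward_fn, *args):
  """Illustrative sketch only; assumes `tape` is the module imported above."""
  with tape.VariableWatcher() as variable_watcher:
    result = forward_fn(*args)
  # Deduplicate by reference, mirroring _eager_mode_decorator below.
  watched = [
      v.deref()
      for v in set(v.ref() for v in variable_watcher.watched_variables())
  ]
  return result, watched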
def _eager_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for eager mode."""
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args, **kwargs)
  args = nest.flatten(args)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [
      v.deref()  # pylint: disable=g-complex-comprehension
      for v in set(v.ref() for v in variable_watcher.watched_variables())
      if all(v.deref() is not i for i in all_inputs)
  ]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      ("variables" not in grad_argspec.kwonlyargs) and
      not grad_argspec.varkw):
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [
      ops.convert_to_tensor(x) for x in list(args) + list(variables)
  ]
  recorded_inputs = input_tensors
  arg_count = len(args)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      raise ValueError(
          "custom_gradient function expected to return {} gradients, but "
          "returned {} instead.".format(arg_count, len(flat_grads)))
    return flat_grads + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, recorded_inputs,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)
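
# A minimal usage sketch of the contract _eager_mode_decorator enforces
# (illustrative only, assuming eager execution and the public
# tf.custom_gradient API): when the forward pass uses variables, grad_fn must
# accept a `variables` keyword and return one gradient per variable alongside
# the input gradients.
def _example_custom_gradient_with_variable():
  """Illustrative sketch only; not part of the library."""
  import tensorflow as tf  # Public API; the module itself uses internal imports.

  w = tf.Variable(2.0)

  @tf.custom_gradient
  def scale(x):
    y = w * x

    def grad(dy, variables=None):
      # `variables` is expected to be [w] here; return (input grads,
      # variable grads) so the length check in actual_grad_fn passes.
      del variables
      return dy * w, [dy * x]

    return y, grad

  with tf.GradientTape() as t:
    out = scale(tf.constant(3.0))
  return t.gradient(out, [w])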
def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keyword "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  default_graph = ops.get_default_graph()

  def convert_arg(x):
    x = ops.convert_to_tensor(x)
    # If graph building, be sure to capture all inputs.
    if default_graph.building_function and x.graph != default_graph:
      x = default_graph.capture(x)
    return x

  args = nest.map_structure(convert_arg, args)

  # Checking global and local variables attempts to ensure that no
  # non-resource Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)
  args = nest.flatten(args)
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)
  after_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset(
      [v.ref() for v in variable_watcher.watched_variables()])
  variables_in_subgraph = frozenset([
      v.ref()
      for v in _get_dependent_variables(input_ops=args, output_ops=flat_result)
  ])
  variables = list(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            "variables" in grad_argspec.kwonlyargs or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                 "no ResourceVariables were used on the forward pass.")

  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(f.__name__, all_tensors, original_tensors,
                            tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])
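
# The graph-mode decorator above routes the backward pass through an IdentityN
# op whose registered gradient is replaced via gradient_override_map. A small
# illustrative sketch of that TF1-style mechanism, using a hypothetical
# registration name "CustomGradExample" and the plain Identity op:
def _example_gradient_override_map():
  """Illustrative sketch only; not part of the library."""
  import tensorflow as tf  # Public API; the module itself uses internal imports.

  @tf.RegisterGradient("CustomGradExample")
  def _scaled_identity_grad(unused_op, dy):  # pylint: disable=unused-variable
    return dy * 0.5  # Override Identity's usual pass-through gradient.

  with tf.Graph().as_default() as g:
    x = tf.constant(3.0)
    with g.gradient_override_map({"Identity": "CustomGradExample"}):
      y = tf.identity(x)
    # dy/dx evaluates to 0.5 instead of 1.0 because of the override.
    return tf.compat.v1.gradients(y, x)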
def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keyword "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no
  # non-resource Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)
  args = nest.flatten(args)
  after_vars = set([
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  ])
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # It is possible for the caller to pass in an input that is from a different
  # graph. Even though this is not valid we filter these out if they are not
  # from the output graph to make it easier for some code to migrate to custom
  # gradients.
  inputs = nest.flatten(args)
  outputs = nest.flatten(result)
  graphs = {getattr(o, "graph", None) for o in outputs}
  # Not all results may be tensors. However, we want to ensure that all outputs
  # are from the same graph and use that to filter the inputs.
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError("All graph outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_inputs = []
    for i in inputs:
      if i.graph != output_graph:
        logging.warn("%s does not belong to output graph %s", i, output_graph)
      else:
        filtered_inputs.append(i)
    inputs = filtered_inputs

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset([
      v.ref() for v in variable_watcher.watched_variables()
  ]) - frozenset(v.ref() for v in inputs)
  variables_in_subgraph = frozenset([
      v.ref()
      for v in get_dependent_variables(input_ops=inputs, output_ops=outputs)
  ])
  variables = list(
      [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                 "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)
  all_tensors = flat_result + inputs + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])