def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  While the wrapper is being traced, the captures of
  `concrete_function.graph` are temporarily redirected to fresh
  `read_value()` tensors created in the outer graph. The original captures
  are always put back before this function exits, including when tracing
  raises.

  Args:
    concrete_function: A Concrete function that maybe captures cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(concrete_function.graph.name))
  captures = concrete_function.graph._captures  # pylint: disable=protected-access
  # Original capture entries, keyed by id(capture), for every entry replaced.
  remapped_captures = {}

  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      for capture, placeholder in concrete_function.graph.captures:
        cached_variable = getattr(capture, "_cached_variable", None)
        if cached_variable is None:
          continue
        # `_cached_variable` is a callable that yields the variable object.
        cached_variable = cached_variable()
        new_cached_value = cached_variable.read_value()
        # Record the original entry before overwriting it so that it can be
        # restored no matter how this function exits.
        remapped_captures[id(capture)] = captures[id(capture)]
        captures[id(capture)] = (new_cached_value, placeholder)

    if not remapped_captures:
      # Nothing was remapped, so there is nothing to wrap.
      return concrete_function

    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(
        concrete_function.structured_input_signature, expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)

    # pylint: disable=protected-access
    fn = defun.ConcreteFunction(
        outer_graph, function_spec=concrete_function._function_spec)
    fn._arg_keywords = concrete_function._arg_keywords
    fn._num_positional_args = concrete_function._num_positional_args
    # pylint: enable=protected-access
    return fn
  finally:
    # Return the captures to their original values, even if tracing failed,
    # so the caller's concrete function is never left in a modified state.
    for key, capture in remapped_captures.items():
      captures[key] = capture
def testMaybeSetStaticShapeScalarShape(self):
  """Op types match with and without `disableSetStaticShape` (shape [-1]).

  Traces the same reshape-of-placeholder function once inside
  `self.disableSetStaticShape()` and once outside it, then asserts that both
  traced graphs contain the same multiset of op types.
  """

  def reshape():
    placeholder = array_ops.placeholder(dtypes.float32)
    return array_ops.reshape(placeholder, [-1])

  with self.disableSetStaticShape():
    graph_without_shape_propagation = func_graph.func_graph_from_py_func(
        "without_shape_propagation", reshape, [], {})
  graph_with_shape_propagation = func_graph.func_graph_from_py_func(
      "with_shape_propagation", reshape, [], {})

  op_types_without = [
      op.type for op in graph_without_shape_propagation.get_operations()
  ]
  op_types_with = [
      op.type for op in graph_with_shape_propagation.get_operations()
  ]
  self.assertCountEqual(op_types_without, op_types_with)
def testMaybeSetStaticShape(self):
  """Op types match with and without `disableSetStaticShape` (constant shape).

  The target shape is a constant tensor created outside the traced function.
  The function is traced once inside `self.disableSetStaticShape()` and once
  outside it, and both graphs must contain the same multiset of op types.
  """
  target_shape = constant_op.constant([2, 5], dtype=dtypes.int32)

  def reshape():
    zeros = array_ops.zeros([10])
    return array_ops.reshape(zeros, target_shape)

  with self.disableSetStaticShape():
    graph_without_shape_propagation = func_graph.func_graph_from_py_func(
        "without_shape_propagation", reshape, [], {})
  graph_with_shape_propagation = func_graph.func_graph_from_py_func(
      "with_shape_propagation", reshape, [], {})

  op_types_without = [
      op.type for op in graph_without_shape_propagation.get_operations()
  ]
  op_types_with = [
      op.type for op in graph_with_shape_propagation.get_operations()
  ]
  self.assertCountEqual(op_types_without, op_types_with)
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: The predicate; must not be a Python bool. It is converted to a
      tensor and squeezed unless its static shape is known to have no dims.
    true_fn: Callable traced into the "true" branch graph.
    false_fn: Callable traced into the "false" branch graph.
    name: Name scope for the op; falls back to "cond" when falsy.

  Returns:
    Whatever `_build_cond` returns for the two traced branches.

  Raises:
    TypeError: If `pred` is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  with ops.name_scope(name or "cond") as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies  # pylint: disable=protected-access
    pred = ops.convert_to_tensor(pred)
    if tensor_util.is_tensor(pred):
      # Squeeze when the rank is unknown or the static shape has dims.
      if pred.shape.dims is None or pred.shape.dims:
        pred = array_ops.squeeze_v2(pred)

    def _trace_branch(branch_name, branch_fn):
      """Traces one branch callable into its own CondBranchFuncGraph."""
      branch_graph = util.CondBranchFuncGraph(
          branch_name,
          collections=ops.get_default_graph()._collections)  # pylint: disable=protected-access
      return func_graph_module.func_graph_from_py_func(
          branch_name,
          branch_fn,
          [],
          {},
          func_graph=branch_graph,
          add_control_dependencies=add_control_dependencies,
          op_return_value=pred)

    true_graph = _trace_branch(true_name, true_fn)
    false_graph = _trace_branch(false_name, false_fn)

    verify_captures(_COND, [true_graph, false_graph])
    return _build_cond(
        pred,
        true_graph,
        false_graph,
        true_graph.external_captures,
        false_graph.external_captures,
        building_gradient=False,
        name=scope)
def _compile_function(func,
                      args,
                      scope,
                      control_outputs,
                      allow_external_captures=False):
  """Traces `func` into a FuncGraph, inheriting state from the parent graph.

  Args:
    func: The Python callable to trace.
    args: Arguments for `func`; converted with `ops.convert_n_to_tensor`.
    scope: Scope passed to `util.unique_fn_name` to name the function.
    control_outputs: Control outputs to add to the traced graph when they are
      not already present.
    allow_external_captures: When False, any external capture whose dtype is
      not `resource` raises `_InvalidCaptureException`.

  Returns:
    A 2-tuple `(func_graph, captured_args)`, where `captured_args` is the
    converted `args` followed by the graph's external captures.

  Raises:
    _InvalidCaptureException: For a disallowed external capture (see above).
  """
  # pylint: disable=protected-access
  parent_graph = ops.get_default_graph()

  # Automatic control dependencies are added in defuns, but not in v1
  # graphs. Propagate that behavior here.
  add_control_dependencies = parent_graph._add_control_dependencies

  # Functions inherit frontend attributes and the gradient override map from the
  # parent graph.
  frontend_proto = xla_data_pb2.FrontendAttributes()
  parent_value = parent_graph._attr_scope_map.get(
      scopes.FRONTEND_ATTRIBUTES_NAME)
  if parent_value:
    frontend_proto.ParseFromString(parent_value.s)
  attribute = attr_value_pb2.AttrValue(s=frontend_proto.SerializeToString())
  gradient_override_map = dict(parent_graph._gradient_override_map)

  def func_wrapper(*args, **kwargs):
    # Add the frontend attributes to the current attributes.
    graph = ops.get_default_graph()
    attributes = dict(graph._attr_scope_map)
    attributes[scopes.FRONTEND_ATTRIBUTES_NAME] = attribute
    with graph._attr_scope(attributes), graph.gradient_override_map(
        gradient_override_map):
      return func(*args, **kwargs)

  # pylint: enable=protected-access

  func_name = util.unique_fn_name(scope, "func")
  captured_args = ops.convert_n_to_tensor(args)

  # Compile the function to a graph.
  func_graph = func_graph_module.func_graph_from_py_func(
      func_name,
      func_wrapper,
      captured_args, {},
      add_control_dependencies=add_control_dependencies)

  # Add the external captures (resources) to arguments.
  for captured in func_graph.external_captures:
    if allow_external_captures:
      continue
    if captured.dtype != dtypes.resource:
      raise _InvalidCaptureException(captured.name)
  captured_args += func_graph.external_captures

  # Add any control outputs. Autograph will add control outputs to the graph
  # automatically, so only add ones which are not already present.
  for control_output in control_outputs:
    if control_output not in func_graph.control_outputs:
      func_graph.control_outputs.append(control_output)

  # Fix shape inference for the gradients and extract_outside_compilation_pass.
  for op in func_graph.get_operations():
    output_shapes = [out.get_shape() for out in op.outputs]
    # pylint: disable=protected-access
    op._set_shape_list_attr("_output_shapes", output_shapes)
    op._set_shape_list_attr("_xla_inferred_shapes", output_shapes)
    # pylint: enable=protected-access

  return func_graph, captured_args
def testMaybeSetStaticShape(self):
  """Op types match with and without `disableSetStaticShape` (in a Graph).

  Both traces are performed with a freshly constructed `ops.Graph()` as the
  default graph. One trace is made inside `self.disableSetStaticShape()` and
  one outside it; both must contain the same multiset of op types.
  """
  target_shape = constant_op.constant([2, 5], dtype=dtypes.int32)

  def reshape():
    zeros = array_ops.zeros([10])
    return array_ops.reshape(zeros, target_shape)

  # This test needs a placeholder which means we need to construct a graph.
  with ops.Graph().as_default():
    with self.disableSetStaticShape():
      graph_without_shape_propagation = func_graph.func_graph_from_py_func(
          "without_shape_propagation", reshape, [], {})
    graph_with_shape_propagation = func_graph.func_graph_from_py_func(
        "with_shape_propagation", reshape, [], {})

    op_types_without = [
        op.type for op in graph_without_shape_propagation.get_operations()
    ]
    op_types_with = [
        op.type for op in graph_with_shape_propagation.get_operations()
    ]
    self.assertCountEqual(op_types_without, op_types_with)
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Args:
    op: The forward While op.
    *grads: Incoming gradients, paired positionally with `op.outputs`. None
      entries are replaced with zeros.

  Returns:
    The slice `outputs[2:2 + len(op.inputs)]` of the gradient While op's
    outputs, i.e. everything after the loop counter and iteration count.
  """
  body_graph = _get_body_graph(op)

  # Replace None gradients with zeros. This is needed because `grads` could have
  # None incoming gradients for the TensorLists. If we pass None's through, the
  # custom gradient of TensorListPopBack will create an EmptyTensorList inside
  # the FuncGraph which is undesirable.
  # TODO(b/80444525): There might be an issue with treating no gradient as zero
  # gradient in certain cases. Consider replacing None gradients with Zeros
  # for accumulators only.
  grads = [
      g if g is not None else array_ops.zeros_like(output)
      for g, output in zip(grads, op.outputs)
  ]

  body_grad_graph, args = _create_grad_func(
      body_graph, grads, util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  # For every intermediate tensor of the gradient body, create an empty tensor
  # list outside the body, capture it inside the body, push the intermediate
  # onto it and expose the result as an extra body output.
  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))

    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
                                                            intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  # The gradient loop keeps running while its counter is below the second loop
  # variable; all other loop variables are ignored by the condition.
  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  # The cond and body must agree with `loop_vars` on arity.
  assert len(loop_vars) == len(body_grad_graph.inputs)
  assert len(loop_vars) == len(body_grad_graph.outputs)
  assert len(loop_vars) == len(cond_grad_graph.inputs)

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  return outputs[2:2 + len(op.inputs)]
def indexed_case(branch_index, branch_fns, name="indexed_case"):
  """Like cond_v2, except emits a Case op instead of an If.

  Args:
    branch_index: Selector for the branch; must not be a Python int. It is
      converted to a tensor named "branch_index".
    branch_fns: Sequence of callables, each traced into its own branch graph.
    name: Name scope for the op.

  Returns:
    Whatever `_build_case` returns for the traced branch graphs.

  Raises:
    TypeError: If `branch_index` is a Python int.
  """
  if isinstance(branch_index, int):
    raise TypeError("branch_index must not be a Python int", branch_index)

  with ops.name_scope(name) as scope:
    branch_names = [
        util.unique_fn_name(scope, "branch{}".format(index))
        for index in range(len(branch_fns))
    ]

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies  # pylint: disable=protected-access
    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")

    branch_graphs = [
        func_graph_module.func_graph_from_py_func(
            branch_name,
            branch_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                branch_name,
                collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=branch_index)
        for branch_name, branch_fn in zip(branch_names, branch_fns)
    ]

    verify_captures(_CASE, branch_graphs)
    external_captures = [graph.external_captures for graph in branch_graphs]
    return _build_case(
        branch_index, branch_graphs, external_captures, name=scope)
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: The predicate; must not be a Python bool.
    true_fn: Callable traced into the "true" branch graph.
    false_fn: Callable traced into the "false" branch graph.
    name: Name scope for the op; falls back to "cond" when falsy.

  Returns:
    The outputs of `_build_cond`, packed into the structure of the true
    branch's `structured_outputs`.

  Raises:
    TypeError: If `pred` is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  with ops.name_scope(name or "cond") as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies  # pylint: disable=protected-access
    pred = ops.convert_to_tensor(pred)

    # Trace the true branch first, then the false branch.
    true_graph, false_graph = [
        func_graph_module.func_graph_from_py_func(
            branch_name,
            branch_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                branch_name,
                collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)
        for branch_name, branch_fn in ((true_name, true_fn),
                                       (false_name, false_fn))
    ]

    flat_outputs = _build_cond(
        pred,
        true_graph,
        false_graph,
        true_graph.external_captures,
        false_graph.external_captures,
        name=scope)
    return func_graph_module.pack_sequence_as(
        true_graph.structured_outputs, flat_outputs)
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  `fn` is called exactly once with symbolic arguments described by
  `signature`; that call is traced and turned into a graph function. Variables
  created by `fn` are owned by the object this function returns. The returned
  graph function accepts tensors matching the signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub= tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the
      wrapped function
    name: Optional. The name of the function; when given it is appended to
      the "wrapped_function_" prefix to name the traced graph.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = (
      "wrapped_function" if name is None else "wrapped_function_" + name)
  traced_graph = func_graph.func_graph_from_py_func(
      func_graph_name,
      holder,
      args=None,
      kwargs=None,
      signature=signature,
      add_control_dependencies=False,
      collections={})
  return WrappedFunction(
      traced_graph, variable_holder=holder, signature=signature)
def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      maximum_iterations):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned args +
  grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    maximum_iterations: Tensor. The maximum number of iterations.

  Returns:
    2-tuple of (grad_func_graph, args).

  Raises:
    ValueError: If an internal capture is neither a resource nor present in
      the gradient graph's `popped_tensor_lists`.
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  # Loop variables: counter, iteration limit, forward iteration count, then
  # the incoming gradients.
  args = [counter, maximum_iterations, total_iters]
  args.extend(grads)

  grad_graph = _WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                       maximum_iterations, while_op)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *fn_args: _grad_fn(ys, xs, fn_args, body_graph),
      args, {},
      func_graph=grad_graph)

  # Add the popped accumulators to the list of outputs.
  for captured in grad_func_graph.internal_captures:
    if captured in grad_func_graph.popped_tensor_lists:
      extra_output = grad_func_graph.popped_tensor_lists[captured]
    elif captured.dtype == dtypes.resource:
      extra_output = captured
    else:
      raise ValueError(
          "Tensor %s is in list of internal_captures but is"
          " neither a resource nor is in popped_tensor_lists." %
          str(captured))
    grad_func_graph.outputs.append(extra_output)
    grad_func_graph.structured_outputs.append(extra_output)

  return grad_func_graph, args
def wrap_function(self, fn, signature, name=None):
  """Wrap a TF 1.X function and save to functions dictionary.

  `fn` is traced into `self.graph`, then pruned out of
  `self._wrapped_function` using the freshly traced inputs and structured
  outputs.

  Args:
    fn: The python function to trace.
    signature: Signature forwarded to `func_graph_from_py_func`.
    name: Optional key under which the result is stored in
      `self._functions`; defaults to `fn.__name__`.

  Returns:
    The pruned wrapped function (also stored in `self._functions[name]`).
  """
  func_graph.func_graph_from_py_func(
      None,  # Name is unused.
      self._variable_holder.call_with_variable_creator_scope(fn),
      args=None,
      kwargs=None,
      signature=signature,
      add_control_dependencies=False,
      func_graph=self.graph)

  # This code relies on questionable behavior from `func_graph_from_py_func`.
  # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
  # and structured outputs are overwritten. Pretty sure this is a bug,
  # because structured outputs doesn't match up with the outputs...
  #
  # Drop the trailing capture placeholders from the graph inputs. The end
  # index is computed explicitly: slicing with `[:-len(captures)]` would yield
  # an empty list (`[:-0]`) when the graph has no captures at all, discarding
  # every real input.
  graph_inputs = self.graph.inputs
  num_fn_inputs = max(len(graph_inputs) - len(self.graph.captures), 0)
  fn_inputs = graph_inputs[:num_fn_inputs]
  fn_outputs = self.graph.structured_outputs

  wrapped_function = self._wrapped_function.prune(fn_inputs, fn_outputs)
  name = name or fn.__name__
  self._functions[name] = wrapped_function
  return wrapped_function
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: The predicate; must not be a Python bool.
    true_fn: Callable traced into the "true" branch graph.
    false_fn: Callable traced into the "false" branch graph.
    name: Name scope for the op; falls back to "cond" when falsy.

  Returns:
    The outputs of `_build_cond`, packed into the structure of the true
    branch's `structured_outputs`.

  Raises:
    TypeError: If `pred` is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  scope_name = name if name else "cond"
  with ops.name_scope(scope_name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_branch_graph = util.CondBranchFuncGraph(
        true_name, read_only_collections=False)
    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn,
        [],
        {},
        func_graph=true_branch_graph,
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    false_branch_graph = util.CondBranchFuncGraph(
        false_name, read_only_collections=False)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn,
        [],
        {},
        func_graph=false_branch_graph,
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    flat_outputs = _build_cond(
        pred,
        true_graph,
        false_graph,
        true_graph.external_captures,
        false_graph.external_captures,
        name=scope)
    return func_graph_module.pack_sequence_as(
        true_graph.structured_outputs, flat_outputs)
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  `fn` is called exactly once with symbolic arguments described by
  `signature`; that call is traced and turned into a graph function. Variables
  created by `fn` are owned by the object this function returns. The returned
  graph function accepts tensors matching the signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub= tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the
      wrapped function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  traced_graph = func_graph.func_graph_from_py_func(
      name,
      holder,
      args=None,
      kwargs=None,
      signature=signature,
      add_control_dependencies=False)
  wrapped = function.Function(traced_graph, signature=signature)
  # Attach the holder to the returned function object.
  wrapped._variable_holder = holder  # pylint: disable=protected-access
  return wrapped
def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      max_iters):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned args +
  grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    max_iters: the maximum number of iterations, or None if no limit.

  Returns:
    2-tuple of (grad_func_graph, args).

  Raises:
    ValueError: If an internal capture is neither a resource nor present in
      the gradient graph's `popped_tensor_lists`.
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  # Loop variables: counter, forward iteration count, then the incoming
  # gradients.
  args = [counter, total_iters]
  args.extend(grads)

  grad_graph = _WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                       max_iters)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *fn_args: _grad_fn(ys, xs, fn_args, body_graph),
      args, {},
      func_graph=grad_graph)

  # Add the popped accumulators to the list of outputs.
  for captured in grad_func_graph.internal_captures:
    if captured in grad_func_graph.popped_tensor_lists:
      extra_output = grad_func_graph.popped_tensor_lists[captured]
    elif captured.dtype == dtypes.resource:
      extra_output = captured
    else:
      raise ValueError("Tensor %s is in list of internal_captures but is"
                       " neither a resource nor is in popped_tensor_lists." %
                       str(captured))
    grad_func_graph.outputs.append(extra_output)
    grad_func_graph.structured_outputs.append(extra_output)

  return grad_func_graph, args
def _wrap_function(self,
                   fn,
                   args=None,
                   kwargs=None,
                   signature=None,
                   name=None):
  """Internal wrap function method with extended func_graph arguments.

  `fn` is traced into `self.graph` through `_filter_returned_ops`, the
  filtered-out ops are put back into the flattened outputs, and the result is
  pruned out of `self._wrapped_function`.

  Args:
    fn: The python function to trace.
    args: Positional arguments forwarded to `func_graph_from_py_func`.
    kwargs: Keyword arguments forwarded to `func_graph_from_py_func`.
    signature: Signature forwarded to `func_graph_from_py_func`.
    name: Optional key under which the result is stored in
      `self._functions`; defaults to `fn.__name__`.

  Returns:
    The pruned wrapped function (also stored in `self._functions[name]`).
  """
  fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
      self._variable_holder.call_with_variable_creator_scope(fn))

  func_graph.func_graph_from_py_func(
      None,  # Name is unused.
      fn_with_filter_and_scope,
      args=args,
      kwargs=kwargs,
      signature=signature,
      add_control_dependencies=False,
      func_graph=self.graph)

  # This code relies on questionable behavior from `func_graph_from_py_func`.
  # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
  # and structured outputs are overwritten. Pretty sure this is a bug,
  # because structured outputs doesn't match up with the outputs...
  #
  # Drop the trailing capture placeholders from the graph inputs. The end
  # index is computed explicitly: slicing with `[:-len(captures)]` would yield
  # an empty list (`[:-0]`) when the graph has no captures at all, discarding
  # every real input.
  graph_inputs = self.graph.inputs
  num_fn_inputs = max(len(graph_inputs) - len(self.graph.captures), 0)
  fn_inputs = graph_inputs[:num_fn_inputs]

  # Return filtered ops to the flattened outputs.
  flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
  for index, op in returned_ops.items():
    flat_fn_outputs[index] = op
  fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                     flat_fn_outputs)

  name = name or fn.__name__
  wrapped_function = self._wrapped_function.prune(
      fn_inputs, fn_outputs, name, self.graph.structured_input_signature)
  self._functions[name] = wrapped_function
  return wrapped_function
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: The predicate; must not be a Python bool.
    true_fn: Callable traced into the "true" branch graph.
    false_fn: Callable traced into the "false" branch graph.
    name: Name scope for the op; falls back to "cond" when falsy.

  Returns:
    Whatever `_build_cond` returns for the two traced branches.

  Raises:
    TypeError: If `pred` is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  with ops.name_scope(name or "cond") as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies  # pylint: disable=protected-access
    pred = ops.convert_to_tensor(pred)

    # Keyword arguments shared by both branch traces.
    shared_trace_kwargs = dict(
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn,
        [],
        {},
        func_graph=util.CondBranchFuncGraph(
            true_name,
            collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        **shared_trace_kwargs)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn,
        [],
        {},
        func_graph=util.CondBranchFuncGraph(
            false_name,
            collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        **shared_trace_kwargs)

    verify_captures(true_graph, false_graph)
    return _build_cond(
        pred,
        true_graph,
        false_graph,
        true_graph.external_captures,
        false_graph.external_captures,
        name=scope)
def _create_grad_func(ys, xs, grads, func_graph, name, while_op):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned args +
  grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    func_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.

  Returns:
    2-tuple of (grad_func_graph, args).

  Raises:
    ValueError: If an internal capture is neither a resource nor present in
      the gradient graph's `popped_tensor_lists`.
  """
  assert len(ys) == len(grads)

  counter = constant_op.constant(0.)
  # TODO(srbs): For nested while loops will need to lookup this value from
  # the accumulator of the enclosing while loop. For now use as is assuming
  # there is no nesting.
  total_iters = while_op.outputs[0]

  args = [counter, total_iters]
  args.extend(grads)

  grad_graph = _WhileBodyGradFuncGraph(name, func_graph, while_op)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *fn_args: _grad_fn(ys, xs, fn_args, func_graph),
      args, {},
      func_graph=grad_graph)

  # Add the popped accumulators to the list of outputs.
  for captured in grad_func_graph.internal_captures:
    if captured in grad_func_graph.popped_tensor_lists:
      extra_output = grad_func_graph.popped_tensor_lists[captured]
    elif captured.dtype == dtypes.resource:
      extra_output = captured
    else:
      raise ValueError(
          "Tensor %s is in list of internal_captures but is"
          " neither a resource nor is in popped_tensor_lists." %
          str(captured))
    grad_func_graph.outputs.append(extra_output)

  return grad_func_graph, args
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
  """Create a `ConcreteFunction` from `args` and `kwargs`.

  Increments `self.tracing_count`, derives one argument name per positional
  argument (or per `input_signature` entry when a signature is set), and
  traces `self._python_function` into a new FuncGraph.

  Args:
    args: Positional arguments to trace with.
    kwargs: Keyword arguments to trace with.
    override_flat_arg_shapes: Forwarded unchanged to
      `func_graph_from_py_func`.

  Returns:
    The newly created `ConcreteFunction`.
  """
  self.tracing_count += 1

  arglen = (
      len(args) if self.input_signature is None else len(self.input_signature))
  spec = self._function_spec
  base_arg_names = spec.arg_names[:arglen]
  num_missing_args = arglen - len(spec.arg_names)
  vararg_name = spec.vararg_name
  # Produce a list of missing args of the form ["arg_0", "arg_1", ...],
  # where arg is based on the self._function_spec.vararg_name.
  missing_arg_names = [
      "%s_%d" % (vararg_name, i) for i in range(num_missing_args)
  ]
  arg_names = base_arg_names + missing_arg_names

  traced_graph = func_graph_module.func_graph_from_py_func(
      self._name,
      self._python_function,
      args,
      kwargs,
      self.input_signature,
      autograph=self._autograph,
      autograph_options=self._autograph_options,
      arg_names=arg_names,
      override_flat_arg_shapes=override_flat_arg_shapes,
      capture_by_value=self._capture_by_value,
      add_control_dependencies=False,
  )
  return _function.ConcreteFunction(
      traced_graph,
      self._function_attributes,
      # Tell the ConcreteFunction to clean up its graph once it goes out of
      # scope. This is not the default behavior since it gets used in some
      # places (like Keras) where the FuncGraph lives longer than the
      # ConcreteFunction.
      shared_func_graph=False,
  )
def _create_grad_func(func_graph, grads, name, while_op):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned args +
  grad_func_graph.captures.

  Args:
    func_graph: FuncGraph for the forward body function.
    grads: The incoming grads for `func_graph`'s outputs.
    name: Name of the returned gradient function.
    while_op: The forward While op.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(func_graph.outputs) == len(grads)

  loop_counter = constant_op.constant(0.)
  # TODO(srbs): For nested while loops will need to lookup this value from
  # the accumulator of the enclosing while loop. For now use as is assuming
  # there is no nesting.
  num_iters_t = while_op.outputs[0]

  args = [loop_counter, num_iters_t] + grads

  grad_graph = _WhileBodyGradFuncGraph(name, func_graph)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *fn_args: _grad_fn(func_graph, fn_args),
      args, {},
      func_graph=grad_graph)

  # Add the popped accumulators to the list of outputs.
  for captured in grad_func_graph.internal_captures:
    popped_list = grad_func_graph.popped_tensor_lists[captured]
    grad_func_graph.outputs.append(popped_list)

  return grad_func_graph, args
def wrap_input_receiver_fn(self, input_receiver_fn):
  """Converts an input receiver function to one or more concrete functions.

  Input receiver functions are python functions with no arguments.
  Placeholders are created within the function and used to receive inputs to
  the model.

  The function (or multiple functions) generated depends on the InputReceiver
  object returned by `input_receiver_fn`.

  Generally, the returned function will have inputs and outputs:
    input_receiver(**receiver_tensors) --> features

    or (if the InputReceiver returns labels):
    input_receiver(**receiver_tensors) --> features, labels

  __Alternate Receiver Tensors__

  The InputReceiver may have alternate receiver tensors, in which case
  additional concrete functions are generated. Example:
    InputReceiver.receiver_tensors_alternatives = {
      'alt_input_1': Tensor,
      'alt_input_2': {
        'tensor_1': Tensor,
        'tensor_2': Tensor
      }
    }

  This will generate concrete functions:

    input_receiver_alt_input_1(input) --> features
    input_receiver_alt_input_2(tensor_1, tensor_2) --> features

  Args:
    input_receiver_fn: a no-argument function that returns an `InputReceiver`
      object.

  Returns:
    A list of tuples of (concrete function, receiver name). The name of the
    default input receiver is `None`.
  """
  # One-element list used to hand the InputReceiver out of the traced closure.
  ret = [None]

  def fn():
    ret[0] = input_receiver = input_receiver_fn()
    features = input_receiver.features
    labels = getattr(input_receiver, 'labels', None)
    return features if labels is None else (features, labels)

  func_graph.func_graph_from_py_func(
      None,  # Name is unused.
      self._variable_holder.call_with_variable_creator_scope(fn),
      args=None,
      kwargs=None,
      signature=[],
      add_control_dependencies=False,
      func_graph=self.graph)

  def _pruned_entry(receiver_tensors, receiver_name):
    """Builds the (concrete function, receiver name) pair for one receiver."""
    pruned_fn = _prune_receiver_tensors(
        self._wrapped_function,
        receiver_tensors=receiver_tensors,
        outputs=self.graph.structured_outputs,
        name=_input_receiver_fn_name(receiver_name))
    return pruned_fn, receiver_name

  input_receiver = ret[0]
  # The default receiver comes first and is registered under the name `None`.
  functions = [_pruned_entry(input_receiver.receiver_tensors, None)]

  receiver_tensors_alternatives = getattr(
      input_receiver, 'receiver_tensors_alternatives', None)
  if not receiver_tensors_alternatives:
    return functions

  for receiver_name, receiver_tensors_alt in (
      six.iteritems(receiver_tensors_alternatives)):
    canonical_tensors = _canonicalize_receiver_tensors(receiver_tensors_alt)
    functions.append(_pruned_entry(canonical_tensors, receiver_name))
  return functions
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds the gradient body graph with `_create_grad_func`, rewrites the forward
  While op in place when `body_grad_graph.while_op_needs_rewrite` is set, and
  emits a new While op named "<forward op name>_grad".

  Args:
    op: The op whose gradient is requested. The forward While op itself is
      recovered as `op.outputs[0].op` (see the note below).
    *grads: Incoming gradients; they are zipped positionally with the body
      graph outputs and the While op outputs.

  Returns:
    Whatever `_get_structured_grad_output(outputs, grads, body_grad_graph)`
    returns for the identity-wrapped outputs of the gradient While op.
  """
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  # Remember the output count before any rewrite appends accumulator outputs.
  orig_num_params = len(body_graph.outputs)

  # `_maximum_iterations` is only read (and required) inside an XLA context.
  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  parallel_iterations = op.get_attr("parallel_iterations")
  assert not _is_in_xla_context() or maximum_iterations is not None
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)

  grads = [_preprocess_grad(grad, body_out, while_out)
           for grad, body_out, while_out
           in zip(grads, body_graph.outputs, while_op.outputs)]

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  # NOTE(review): if every grad is None the unpacking below raises ValueError;
  # presumably callers never invoke the gradient in that case -- confirm.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    # Outputs appended to the body graph while building the gradient graph.
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    # NOTE(review): this reads `op.outputs` rather than `while_op.outputs`;
    # the two may differ per the note at the top of this function -- confirm.
    _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  # The gradient loop condition only compares the first two loop variables.
  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  _maybe_set_maximum_iterations_attr(grad_op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op.

  Args:
    cond: Callable traced into the cond graph. It is called with the loop
      variables repacked into the structure of `loop_vars`.
    body: Callable traced into the body graph. It is called with the loop
      variables repacked into the structure of `loop_vars`. A non-sequence
      return value is wrapped in a list, and the result must have the same
      nested structure as `loop_vars`.
    loop_vars: Sized sequence of loop variables (`len()` is taken of it).
    shape_invariants: Optional structure matching `loop_vars`. When None, the
      shapes of the converted loop variables are used instead.
    maximum_iterations: Optional iteration bound. When it is not None, the
      loop condition is and-ed with `loop_counter < maximum_iterations`.
    name: Optional name scope for the loop; falls back to "while" when falsy.
    return_same_structure: If True, the result always has the structure of
      `loop_vars`. If False and the result flattens to exactly one element,
      that single element is returned instead.

  Returns:
    The identity-wrapped While op outputs for the original loop variables (the
    internal loop counter and captured tensors are dropped), packed into the
    structure of `loop_vars`, or a single element as described above.
  """
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    # Function names are generated outside of the loop's own name scope.
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    # The counter takes the dtype of `maximum_iterations` when one is given.
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter] + loop_vars

    # Prepend a scalar shape for the counter, keeping the container type.
    shape_invariants = type(shape_invariants)([tensor_shape.scalar()
                                              ]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph(
    )._add_control_dependencies

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(cond_name),
        add_control_dependencies=add_control_dependencies)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + cond_graph.external_captures
    shape_invariants = shape_invariants + type(shape_invariants)(
        [t.shape for t in cond_graph.external_captures])

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:len_orig_loop_vars] - Args for the original loop body.
          args[len_orig_loop_vars:] - External captures of cond. These get
            passed through as is.

      Returns:
        A list of tensors the same length as args.
      """
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(
          *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(
          args[len_orig_loop_vars:])

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(body_name),
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(b/118457764): Dedup tensors that are captured in both the cond and
    # body. This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        assert external_capture not in cond_graph.captures, (
            "Looks like both cond and body are capturing the same tensor %s. "
            "This is not supported yet. For now consider passing,"
            " this as a loop variable." % str(external_capture))
        cond_graph.capture(external_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    # Index 0 is the loop counter, hence the `1:` offsets below.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    _check_shapes_compat(
        body_graph.outputs[1:1 + num_flattened_outputs],
        nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
        nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # First var is loop counter.
  outputs = _pack_sequence_as(orig_loop_vars,
                              outputs[1:1 + num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
def wrap_function(self, fn, signature, name=None):
  """Wraps a TF 1.X function and returns an eager-compatible function.

  All functions wrapped in the same `WrappedGraph` will have access to the
  same graph (`tf.get_default_graph` to get the graph object within a
  function, or `WrappedGraph.graph` to get the graph outside a function).
  Variables created within the function will be added to the `variables`
  list.

  Function inputs: All inputs to the function must be tensors (nested ok),
  with their shapes and dtypes defined in the `signature` argument.

  Function outputs:

    * The 1.X function may return tensors, variables, and ops. The wrapped
      eager-compatible function will always return tensors in the same nested
      structure.
    * Variables are replaced with a tensor containing the latest read values.
    * Returned ops are executed, and replaced with None.
    * The order of op execution and variable reads in the return is
      nondeterministic. For example:

      ```
      def update_var(x):
        v = tf.Variable(0)
        op = tf.compat.v1.assign(v, x).op
        return v, op

      g = WrappedGraph()
      fn = g.wrap_function(update_var, [tf.TensorSpec([], tf.int32)])
      read_value, _ = fn(tf.constant(3))
      print(read_value.numpy())  # could be 0 or 3
      print(g.variables[0].numpy()) # always 3
      ```

  To ensure that ops in the function are executed (e.g. ops added to the
  `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns.

  Args:
    fn: a 1.X tensorflow function.
    signature: a possibly nested sequence of `TensorSpecs` specifying the
      shapes and dtypes of the arguments.
    name: an optional string name for the function. Defaults to
      `fn.__name__`. The function will be saved with key `name` in the
      `functions` dictionary.

  Returns:
    An eager-compatible function.
  """
  fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
      self._variable_holder.call_with_variable_creator_scope(fn))

  func_graph.func_graph_from_py_func(
      None,  # Name is unused.
      fn_with_filter_and_scope,
      args=None,
      kwargs=None,
      signature=signature,
      add_control_dependencies=False,
      func_graph=self.graph)

  # This code relies on questionable behavior from `func_graph_from_py_func`.
  # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
  # and structured outputs are overwritten. Pretty sure this is a bug,
  # because structured outputs doesn't match up with the outputs...
  #
  # The trailing `len(captures)` entries of `graph.inputs` are dropped, keeping
  # only the function's own arguments. The end index is computed explicitly
  # because the negative form `inputs[:-n]` yields an empty list when `n` is 0
  # (`-0 == 0`), which would discard every input of a capture-free function.
  graph_inputs = self.graph.inputs
  num_captures = len(self.graph.captures)
  fn_inputs = graph_inputs[:max(len(graph_inputs) - num_captures, 0)]

  # Return filtered ops to the flattened outputs.
  flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
  for index, op in returned_ops.items():
    flat_fn_outputs[index] = op
  fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                     flat_fn_outputs)

  name = name or fn.__name__
  wrapped_function = self._wrapped_function.prune(fn_inputs, fn_outputs, name)
  self._functions[name] = wrapped_function
  return wrapped_function
def _create_grad_func(func_graph, grads, name):
  """Traces `_grad_fn(func_graph, grads)` into a new gradient FuncGraph.

  Args:
    func_graph: The forward graph handed to `_grad_fn` and to the
      `_CondGradFuncGraph` constructor.
    grads: The incoming gradients handed to `_grad_fn`.
    name: Name used for both the traced function and the gradient graph.

  Returns:
    The FuncGraph produced by `func_graph_from_py_func`.
  """

  def traced_grad_fn():
    return _grad_fn(func_graph, grads)

  grad_graph = _CondGradFuncGraph(name, func_graph)
  return func_graph_module.func_graph_from_py_func(
      name, traced_grad_fn, [], {}, func_graph=grad_graph)
def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op, maximum_iterations): """Builds and returns the gradient FuncGraph of `func_graph` and its args. The returned grad_func_graph must be called with the returned args + grad_func_graph.captures. Args: ys: A `Tensor` or list of tensors to be differentiated. xs: A `Tensor` or list of tensors to be used for differentiation. grads: The incoming grads for `ys`. cond_graph: FuncGraph for the forward cond function. body_graph: FuncGraph for the forward body function. name: Name of the returned gradient function. while_op: The forward While op. maximum_iterations: Tensor. The maximum number of iterations. Returns: 2-tuple of (grad_func_graph, args). """ assert len(ys) == len(grads) total_iters = while_op.outputs[0] counter = constant_op.constant( 0, dtype=total_iters.dtype, name="grad_counter") # Build frozen sets so that we do not have linear time lookups in # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs` # may get updated during gradient computation because we add accumulators to # the forward op. However, those are not loop invariants so wouldn't affect # the output of `_is_loop_invariant`. Also we would never attempt to capture # those accumulators so `_is_loop_invariant` should never receive those new # tensors as args. body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs) body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs) args = [counter, maximum_iterations, total_iters] + list(grads) # Note: The returned function does not have `args` in the list of # `external_captures`. grad_func_graph = func_graph_module.func_graph_from_py_func( name, lambda *args: _grad_fn(ys, xs, args, body_graph), args, {}, func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph, maximum_iterations, while_op, body_graph_inputs, body_graph_outputs)) # Update the list of outputs with tensors corresponding to the captured # tensors. 
We capture 3 types of tensors when building the grad fn: # 1. Accumulators for forward graph intermediates which are not loop # invariants. The outputs corresponding to these are populated in # `popped_tensor_lists` by `_WhileBodyGradFuncGraph`. # 2. Resources, which are output as is. # 3. Forward graph loop invariants, which are output as is. for external_capture, internal_capture in grad_func_graph.captures: if ops.tensor_id(internal_capture) in grad_func_graph.popped_tensor_lists: new_output = grad_func_graph.popped_tensor_lists[ops.tensor_id( internal_capture)] elif (internal_capture.dtype == dtypes.resource or _is_loop_invariant( external_capture, body_graph_inputs, body_graph_outputs)): new_output = internal_capture else: raise ValueError("Tensor %s which captures %s is in list of " "internal_captures but is not a resource, is not in " "popped_tensor_lists and does not capture a loop " "invariant." % (str(internal_capture), str(external_capture))) grad_func_graph.outputs.append(new_output) grad_func_graph.structured_outputs.append(new_output) return grad_func_graph, args
def wrap_function(self, fn, signature, name=None):
  """Wraps a TF 1.X function and returns an eager-compatible function.

  All functions wrapped in the same `WrappedGraph` will have access to the
  same graph (`tf.get_default_graph` to get the graph object within a
  function, or `WrappedGraph.graph` to get the graph outside a function).
  Variables created within the function will be added to the `variables`
  list.

  Function inputs: All inputs to the function must be tensors (nested ok),
  with their shapes and dtypes defined in the `signature` argument.

  Function outputs:

    * The 1.X function may return tensors, variables, and ops. The wrapped
      eager-compatible function will always return tensors in the same nested
      structure.
    * Variables are replaced with a tensor containing the latest read values.
    * Returned ops are executed, and replaced with None.
    * The order of op execution and variable reads in the return is
      nondeterministic. For example:

      ```
      def update_var(x):
        v = tf.Variable(0)
        op = tf.compat.v1.assign(v, x).op
        return v, op

      g = WrappedGraph()
      fn = g.wrap_function(update_var, [tf.TensorSpec([], tf.int32)])
      read_value, _ = fn(tf.constant(3))
      print(read_value.numpy())  # could be 0 or 3
      print(g.variables[0].numpy()) # always 3
      ```

  To ensure that ops in the function are executed (e.g. ops added to the
  `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns.

  Args:
    fn: a 1.X tensorflow function.
    signature: a possibly nested sequence of `TensorSpecs` specifying the
      shapes and dtypes of the arguments.
    name: an optional string name for the function. Defaults to
      `fn.__name__`. The function will be saved with key `name` in the
      `functions` dictionary.

  Returns:
    An eager-compatible function.
  """
  fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
      self._variable_holder.call_with_variable_creator_scope(fn))

  func_graph.func_graph_from_py_func(
      None,  # Name is unused.
      fn_with_filter_and_scope,
      args=None,
      kwargs=None,
      signature=signature,
      add_control_dependencies=False,
      func_graph=self.graph)

  # This code relies on questionable behavior from `func_graph_from_py_func`.
  # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
  # and structured outputs are overwritten. Pretty sure this is a bug,
  # because structured outputs doesn't match up with the outputs...
  #
  # The trailing `len(captures)` entries of `graph.inputs` are dropped, keeping
  # only the function's own arguments. The end index is computed explicitly
  # because the negative form `inputs[:-n]` yields an empty list when `n` is 0
  # (`-0 == 0`), which would discard every input of a capture-free function.
  graph_inputs = self.graph.inputs
  num_captures = len(self.graph.captures)
  fn_inputs = graph_inputs[:max(len(graph_inputs) - num_captures, 0)]

  # Return filtered ops to the flattened outputs.
  flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
  for index, op in returned_ops.items():
    flat_fn_outputs[index] = op
  fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                     flat_fn_outputs)

  name = name or fn.__name__
  wrapped_function = self._wrapped_function.prune(
      fn_inputs, fn_outputs, name)
  self._functions[name] = wrapped_function
  return wrapped_function
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds the gradient body graph with `_create_grad_func`, adds a TensorList
  accumulator output for every intermediate reported by `_get_intermediates`,
  and emits a new While op named "<op name>_grad".

  Args:
    op: The forward While op.
    *grads: Incoming gradients, zipped positionally with `op.outputs`.

  Returns:
    A list with one entry per incoming gradient: None where the (filtered)
    incoming gradient is None, otherwise the next output of the gradient While
    op, starting at output index 2.
  """
  body_graph = _get_body_graph(op)

  # Set the incoming gradient of TensorArray handles to None. The gradient
  # implementation currently assumes all resource tensors correspond to float32
  # ResourceVariables, which can lead to runtime shape errors when used with a
  # TensorArray. This is a workaround until TensorArrays are reimplemented with
  # TensorLists instead of resources.
  # Also set the incoming gradient of non-trainable inputs to None. It is
  # possible that we receive non-None gradients for non-trainable types in
  # nested while loops because we accumulate outputs of the inner while as
  # variant tensors which are trainable and hence receive zeros_like tensors in
  # the gradient pass. The non-trainable tensors then receive the popped zeros
  # tensor from this zeros variant. The gradient for the loop vars corresponding
  # to these tensors is None or zeros (this happens only if the loop var is
  # accumulated as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
  # output grads in _grad_fn.
  grads = [
      None if _is_tensor_array_handle(output) or
      not gradients_impl.IsTrainable(output) else grad
      for grad, output in zip(grads, op.outputs)
  ]

  # Ensure that all non-resource trainable outputs have incoming gradients.
  assert all(g is not None or o.dtype == dtypes.resource or
             not gradients_impl.IsTrainable(o)
             for o, g in zip(op.outputs, grads)
            ), "All trainable loop vars must receive incoming gradients."
  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  # NOTE(review): if every grad is None the unpacking below raises ValueError;
  # presumably callers never invoke the gradient in that case -- confirm.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, body_graph,
      util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  # `_maximum_iterations` is only read (and required) inside an XLA context.
  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  assert not _is_in_xla_context() or maximum_iterations is not None
  for intermediate_tensor in intermediate_tensors:
    # The empty list is created outside `body_grad_graph` and then captured
    # into it below.
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=intermediate_tensor.shape,
        max_num_elements=maximum_iterations)

    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
                                                            intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  # The gradient loop condition only compares the first two loop variables.
  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(outputs[0].op)
  _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]

  # Set None as the output gradient for tensors with None input gradient
  # e.g. TensorArray handles.
  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  index = 2
  none_padded_outputs = []
  for g in grads:
    if g is None:
      none_padded_outputs.append(None)
    else:
      none_padded_outputs.append(outputs[index])
      index += 1

  return none_padded_outputs
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op.

  In addition to the user's loop variables, the emitted While op carries a
  loop counter, the external captures of `cond` and `body`, and one TensorList
  accumulator per intermediate reported by `_get_intermediates(body_graph)`.

  Args:
    cond: Callable traced into the cond graph. It is called with the loop
      variables repacked into the structure of `loop_vars`.
    body: Callable traced into the body graph. It is called with the loop
      variables repacked into the structure of `loop_vars`. A non-sequence
      return value is wrapped in a list, and the result must have the same
      nested structure as `loop_vars`.
    loop_vars: Sized sequence of loop variables (`len()` is taken of it).
    shape_invariants: Optional structure matching `loop_vars`. When None, the
      shapes of the converted loop variables are used instead.
    maximum_iterations: Optional iteration bound. When it is not None, the
      loop condition is and-ed with `loop_counter < maximum_iterations`. It is
      also passed as `max_num_elements` of the accumulator lists.
    name: Optional name scope for the loop; falls back to "while" when falsy.
    return_same_structure: If True, the result always has the structure of
      `loop_vars`. If False and the result flattens to exactly one element,
      that single element is returned instead.

  Returns:
    The identity-wrapped While op outputs for the original loop variables (the
    internal loop counter, captures and accumulators are dropped), packed into
    the structure of `loop_vars`, or a single element as described above.
  """
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    # Function names are generated outside of the loop's own name scope.
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    # The counter takes the dtype of `maximum_iterations` when one is given.
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter] + loop_vars

    # Prepend a scalar shape for the counter, keeping the container type.
    shape_invariants = type(shape_invariants)([tensor_shape.scalar()
                                              ]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(cond_name),
        add_control_dependencies=add_control_dependencies)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + cond_graph.external_captures
    shape_invariants = shape_invariants + type(shape_invariants)(
        [t.shape for t in cond_graph.external_captures])

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:len_orig_loop_vars] - Args for the original loop body.
          args[len_orig_loop_vars:] - External captures of cond. These get
            passed through as is.

      Returns:
        A list of tensors the same length as args.
      """
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(
          *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(
          args[len_orig_loop_vars:])

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(body_name),
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(b/118457764): Dedup tensors that are captured in both the cond and
    # body. This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        assert external_capture not in cond_graph.captures, (
            "Looks like both cond and body are capturing the same tensor %s. "
            "This is not supported yet. For now consider passing,"
            " this as a loop variable." % str(external_capture))
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=intermediate_tensor.shape,
          max_num_elements=maximum_iterations)
      # The accumulator becomes an extra loop variable of the While op.
      loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list, intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    # Index 0 is the loop counter, hence the `1:` offsets below.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    _check_shapes_compat(
        body_graph.outputs[1:1 + num_flattened_outputs],
        nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
        nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # First var is loop counter.
  outputs = _pack_sequence_as(orig_loop_vars,
                              outputs[1:1 + num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Builds a conditional as a single `If` op, mirroring `tf.cond`.

  Both branches are traced into their own function graphs, their inputs and
  outputs are made to match, and intermediate tensors are exported as extra
  outputs so that they are available for the gradient computation.

  Args:
    pred: Predicate selecting the branch. It is converted with
      `ops.convert_to_tensor`; a Python `bool` is rejected.
    true_fn: Callable traced (with no arguments) into the "true" branch.
    false_fn: Callable traced (with no arguments) into the "false" branch.
    name: Name scope for the emitted ops. A falsy value falls back to "cond".

  Returns:
    The single identity-wrapped output when the branches produce exactly one
    output, otherwise a tuple with one identity-wrapped tensor per output.

  Raises:
    TypeError: If `pred` is a Python `bool`.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  with ops.name_scope(name or "cond") as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Defuns add automatic control dependencies while v1 graphs do not; the
    # branch graphs follow whichever context we are being built in.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    def _trace_branch(branch_name, branch_fn):
      """Traces `branch_fn` into a new `CondBranchFuncGraph`."""
      branch_graph = util.CondBranchFuncGraph(
          branch_name, read_only_collections=False)
      return func_graph_module.func_graph_from_py_func(
          branch_name,
          branch_fn, [], {},
          func_graph=branch_graph,
          add_control_dependencies=add_control_dependencies,
          op_return_value=pred)

    true_graph = _trace_branch(true_name, true_fn)
    false_graph = _trace_branch(false_name, false_fn)
    _check_same_outputs(true_graph, false_graph)

    # Give both branches the same inputs. Note that this mutates true_graph
    # and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Every intermediate tensor becomes a function output so that it is
    # available for the gradient computation.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Remember how many outputs belong to the caller before padding.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of the new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)
    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    output_dtypes = [t.dtype for t in true_graph.outputs]
    then_branch = util.create_new_tf_function(true_graph)
    else_branch = util.create_new_tf_function(false_graph)
    output_shapes = _get_output_shapes(true_graph.outputs,
                                       false_graph.outputs)
    if_outputs = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs,
        output_dtypes,
        then_branch,
        else_branch,
        output_shapes=output_shapes,
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary.
    # Lowering lets cond_v2 avoid some of the limitations of Functions: users
    # can specify devices & colocation inside of cond_v2 branches, and it
    # enables non-strict evaluation & partial pruning of the branches. This
    # brings cond_v2 closer to feature parity with tf.cond.
    #
    # `If` is not lowered in the XLA context because it is easier for XLA to
    # apply its own optimizations to un-lowered `If` operators than to
    # lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = if_outputs[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    # Hand back identities of the If outputs rather than the outputs
    # themselves. This makes pruning work if the output of cond() is fetched:
    # the lowering pass converts the If outputs into IdentityN outputs, which
    # if fetched will cause all ops in the taken branch to be run (since it
    # takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we covert to the
    # correct output structure
    identities = tuple(array_ops.identity(t) for t in if_outputs)

    result = identities[:num_cond_outputs]
    return result[0] if len(result) == 1 else result
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op.

  Two bookkeeping loop variables are prepended to the caller's `loop_vars`: a
  loop counter (needed for computing gradients) and the maximum-iterations
  loop variable. Tensors captured by `cond` and `body` are appended as
  additional loop variables. Only the outputs that correspond to the caller's
  `loop_vars` are returned.

  Args:
    cond: Callable invoked with the loop variables packed into the structure
      of `loop_vars`; its result is the loop predicate.
    body: Callable invoked with the same packed loop variables. Its outputs
      must have the same nested structure as `loop_vars` (top-level tuples and
      lists are treated alike).
    loop_vars: Sequence of initial loop variables. TensorArrays are replaced
      by their flow variables internally and converted back before `cond` and
      `body` are called.
    shape_invariants: Optional structure matching `loop_vars`. When None, the
      `shape` of each converted loop variable is used.
    parallel_iterations: Forwarded to the While op.
    maximum_iterations: Optional iteration limit. When not None, the predicate
      becomes `loop_counter < maximum_iterations_arg` AND-ed with `cond(...)`.
    name: Optional name scope for the emitted ops; falls back to "while".
    return_same_structure: If True, always return the outputs packed into the
      structure of `loop_vars`. If False and there is exactly one flattened
      output, that output is returned on its own.

  Returns:
    The loop outputs packed into the structure of `loop_vars`, or the single
    flattened output when `return_same_structure` is False and there is
    exactly one.
  """
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    # The counter takes the dtype of the maximum-iterations loop var when a
    # limit was given so the two can be compared in `wrapped_cond`.
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    # Both bookkeeping loop vars are scalars; keep the container type of the
    # caller's shape invariants.
    shape_invariants = type(shape_invariants)(
        [tensor_shape.scalar(), tensor_shape.scalar()]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph(
    )._add_control_dependencies

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop. Passed through
          to the outputs unchanged.
        *args: List of args

      Returns:
        A list of tensors: the incremented counter, the maximum-iterations
        value, and the outputs of `body` with TensorArrays replaced by their
        flow variables.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      # `wrapped_body` re-captured cond's captures first, so they must form a
      # prefix of the body's captures.
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      for body_capture in body_graph.external_captures[
          num_cond_captures:]:
        assert body_capture not in cond_graph.captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars]),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    # The While op is created under the control captures recorded by both
    # function graphs.
    with ops.control_dependencies(
        list(cond_graph.control_captures) +
        list(body_graph.control_captures)):
      outputs = gen_functional_ops._while(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=[t.shape for t in body_graph.outputs],
          parallel_iterations=parallel_iterations,
          name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # Drop the bookkeeping outputs and anything past the caller's loop vars,
  # then restore the caller's structure.
  outputs = _pack_sequence_as(
      orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                              num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op.

  A loop counter (needed for computing gradients) is prepended to the
  caller's `loop_vars`, and tensors captured by `cond` and `body` are appended
  as additional loop variables. Only the outputs that correspond to the
  caller's `loop_vars` are returned.

  Args:
    cond: Callable invoked with the loop variables packed into the structure
      of `loop_vars`; its result is the loop predicate.
    body: Callable invoked with the same packed loop variables. Its outputs
      must have the same nested structure as `loop_vars` (top-level tuples and
      lists are treated alike).
    loop_vars: Sequence of initial loop variables. TensorArrays are replaced
      by their flow variables internally and converted back before `cond` and
      `body` are called.
    shape_invariants: Optional structure matching `loop_vars`. When None, the
      `shape` of each converted loop variable is used.
    parallel_iterations: Forwarded to the While op.
    maximum_iterations: Optional iteration limit, passed through
      `_validate_and_convert_to_tensor`. When the result is not None, the
      predicate becomes `loop_counter < maximum_iterations` AND-ed with
      `cond(...)`.
    name: Optional name scope for the emitted ops; falls back to "while".
    return_same_structure: If True, always return the outputs packed into the
      structure of `loop_vars`. If False and there is exactly one flattened
      output, that output is returned on its own.

  Returns:
    The loop outputs packed into the structure of `loop_vars`, or the single
    flattened output when `return_same_structure` is False and there is
    exactly one.
  """
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    # The counter takes the dtype of `maximum_iterations` when a limit was
    # given so the two can be compared in `wrapped_cond`.
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter] + loop_vars

    # The counter is a scalar; keep the container type of the caller's shape
    # invariants.
    shape_invariants = type(shape_invariants)([tensor_shape.scalar()
                                              ]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args

      Returns:
        A list of tensors: the incremented counter followed by the outputs of
        `body` with TensorArrays replaced by their flow variables.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      # `wrapped_body` re-captured cond's captures first, so they must form a
      # prefix of the body's captures.
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph.captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    # Index 0 is the loop counter, so the caller's loop vars start at 1.
    _check_shapes_compat(
        body_graph.outputs[1:1 + num_flattened_outputs],
        nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
        nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        parallel_iterations=parallel_iterations,
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # First var is loop counter.
  outputs = _pack_sequence_as(orig_loop_vars,
                              outputs[1:1 + num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds a second While op whose body computes the gradient of the forward
  loop body. When the gradient body needs intermediate accumulators from the
  forward loop, the forward While op is modified in place to take and output
  them.

  Args:
    op: The forward op as seen by the gradient machinery. It is not always the
      While op itself, which is recovered via `op.outputs[0].op`.
    *grads: Incoming gradients, aligned with the outputs of the While op.

  Returns:
    The result of `_get_structured_grad_output` applied to the
    identity-wrapped outputs of the gradient While op.
  """
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  # `get_attr` raises ValueError when the attr is missing; in that case every
  # current output counts as an original output. Catching only ValueError
  # (instead of a bare `except`) keeps KeyboardInterrupt/SystemExit and
  # unrelated failures from being silently swallowed.
  try:
    num_original_outputs = while_op.get_attr("_num_original_outputs")
  except ValueError:
    num_original_outputs = len(while_op.outputs)

  num_intermediates = len(while_op.outputs) - num_original_outputs
  # Gradients w.r.t. the intermediate outputs are dropped (set to None).
  grads = [
      _preprocess_grad(grad, body_out, while_out)  # pylint: disable=g-complex-comprehension
      for grad, body_out, while_out in zip(
          grads[:num_original_outputs],
          body_graph.outputs[:num_original_outputs],
          while_op.outputs[:num_original_outputs])
  ] + [None] * num_intermediates

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

  # Do not ignore grads wrt extra outputs when computing higher order
  # derivatives.
  while_op._set_attr("_num_original_outputs",
                     attr_value_pb2.AttrValue(i=len(while_op.outputs)))

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    """Keeps the gradient loop running while `counter < forward_loop_iters`."""
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  util.maybe_propagate_compile_time_consts_in_xla(grad_op)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
  """Like tf.while_loop, except emits a single While op.

  The loop variables are flattened and a loop counter (needed for computing
  gradients) is prepended. Tensors captured by `cond` and `body` are appended
  as extra loop variables, and every intermediate tensor of the body is
  accumulated in a TensorList so that it is available for the gradient
  computation.

  Args:
    cond: Callable invoked with the flattened loop variables; its result is
      the loop predicate.
    body: Callable invoked with the flattened loop variables. A result that is
      not a sequence is treated as a single output.
    loop_vars: (Possibly nested) structure of initial loop variables.
    shape_invariants: Optional structure matching `loop_vars`. When None, the
      `shape` of each flattened loop variable is used.
    name: Optional name scope for the emitted ops; falls back to "while".

  Returns:
    The single identity-wrapped output when there is exactly one flattened
    loop variable, otherwise the outputs packed into the structure of
    `loop_vars`.
  """
  # `collections.Sequence` was removed in Python 3.10. Resolve the ABC from
  # `collections.abc` where it exists and fall back to `collections` on
  # Python 2, which has no `collections.abc`.
  # pylint: disable=g-import-not-at-top
  try:
    from collections import abc as collections_abc
  except ImportError:
    import collections as collections_abc
  # pylint: enable=g-import-not-at-top

  flattened_loop_vars = nest.flatten(loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(loop_vars, shape_invariants)
    flattened_shapes = nest.flatten(shape_invariants)
  else:
    flattened_shapes = [t.shape for t in flattened_loop_vars]

  # Only the flattened shapes are used from here on.
  del shape_invariants

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    num_outputs = len(flattened_loop_vars)

    # Add loop counter needed for computing gradients.
    flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
                          ] + flattened_loop_vars

    flattened_shapes = [tensor_shape.scalar()] + flattened_shapes

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(unused_loop_counter, *loop_vars):
      return cond(*loop_vars)

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        flattened_loop_vars, {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(cond_name),
        add_control_dependencies=add_control_dependencies)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
    flattened_shapes = flattened_shapes + [
        t.shape for t in cond_graph.external_captures
    ]

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:num_outputs] - Args for the original loop body.
          args[num_outputs:] - External captures of cond. These get passed
            through as is.

      Returns:
        A list of tensors the same length as args.
      """
      outputs = body(*args[:num_outputs])
      if not isinstance(outputs, collections_abc.Sequence):
        outputs = [outputs]

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])

    # Rebuilt because the loop vars now include cond's external captures.
    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        flattened_loop_vars, {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(body_name),
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(srbs): Dedup tensors that are captured in both the cond and body.
    # This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      # TODO(srbs): Cache and re-use empty tensor lists.
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=_get_tensor_convertible_shape(
              intermediate_tensor.shape))
      flattened_loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list, intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
                         flattened_shapes[1:1 + num_outputs],
                         flattened_loop_vars[1:1 + num_outputs])
    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    _maybe_set_lowering_attr(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # First var is loop counter.
  if num_outputs == 1:
    return outputs[1]
  else:
    return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op.

  `cond` and `body` are each traced into their own FuncGraph and the two graphs
  are connected by one functional while op: `gen_functional_ops._while` when
  either graph contains a stateful op, `gen_functional_ops.stateless_while`
  otherwise.

  Two bookkeeping loop variables are prepended to the user's loop variables: a
  loop counter and the maximum-iterations value. They are stripped again
  before the result is returned.

  Args:
    cond: Callable producing the loop predicate. It is called with the loop
      variables packed into the structure of `loop_vars`.
    body: Callable producing the next value of the loop variables. It is called
      with the same arguments as `cond`, and its result must have the same
      nested structure as `loop_vars`.
    loop_vars: The initial loop variables. Must support `len()`.
    shape_invariants: Optional shape invariants. When given, they must have the
      same structure as `loop_vars`.
    parallel_iterations: Forwarded unchanged to the while op.
    maximum_iterations: Optional iteration limit. When not None, the predicate
      becomes `loop_counter < maximum_iterations AND cond(...)`.
    name: Optional name scope for the created ops. Defaults to "while".
    return_same_structure: If True, the result always has the structure of
      `loop_vars`. If False, a result that flattens to exactly one element is
      returned as that bare element.
    back_prop: If True and `util.output_all_intermediates()` is true, the
      intermediate tensors of the body are additionally accumulated in
      TensorLists so that they are available for gradient computation.

  Returns:
    The final values of the loop variables (identity ops on the while op's
    outputs), structured as described under `return_same_structure`.
  """
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  # Build the tracing signature and the per-variable shape invariants, either
  # from the caller-supplied invariants or from the loop variables themselves.
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)

  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    # The function names are derived from `scope` explicitly, so they are
    # generated with the name scope reset.
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    # The counter uses the dtype of the maximum-iterations variable when a
    # limit was given, because the two are compared in `wrapped_cond`.
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    # The two bookkeeping variables get scalar shape invariants and specs.
    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      # Squeeze tensor predicates whose rank is unknown or non-zero.
      if (tensor_util.is_tensor(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list holding the incremented loop counter, `maximum_iterations_arg`
        unchanged, and then the outputs of `body` with TensorArrays converted
        to their flows.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      # `wrapped_body` re-captured cond's captures first, so they must form a
      # prefix of the body's captures.
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      cond_graph_captures = object_identity.ObjectIdentitySet(
          cond_graph.external_captures)
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph_captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    # Recorded before accumulators are appended; stored on the op below so the
    # accumulator outputs can be told apart from the original ones.
    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    # Control captures of both function graphs become control inputs of the
    # while op created inside this block.
    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      # Start from the shapes the body produced, then overwrite the entries of
      # the user's loop variables with their shape invariants.
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      cond_stateful_ops = [
          op for op in cond_graph.get_operations() if op._is_stateful
      ]
      body_stateful_ops = [
          op for op in body_graph.get_operations() if op._is_stateful
      ]
      # Only use the stateless op when neither function has a stateful op.
      if (cond_stateful_ops or body_stateful_ops):
        op_fn = gen_functional_ops._while
      else:
        op_fn = gen_functional_ops.stateless_while

      outputs = op_fn(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope)
      # This is needed so we do not compute derivative wrt these extra outputs.
      outputs[0].op._set_attr("_num_original_outputs",
                              attr_value_pb2.AttrValue(i=num_original_outputs))

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # Drop the two bookkeeping variables and any trailing captures/accumulators,
  # then restore the caller's structure.
  outputs = _pack_sequence_as(
      orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                              num_flattened_outputs])

  if return_same_structure:
    return outputs

  # Legacy behavior: unwrap a single-element result.
  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds a gradient body graph for the forward body, accumulates the
  intermediates of that gradient body in TensorLists, and emits a second While
  op that runs the gradient body.

  Args:
    op: The forward While operation.
    *grads: Incoming gradients, one per output of `op`. Entries may be None.

  Returns:
    A list with one entry per incoming gradient. An entry is None when the
    corresponding incoming gradient was None or was discarded by the filter
    below (TensorArray handles and non-trainable outputs); otherwise it is the
    matching output of the gradient While op.
  """
  body_graph = _get_body_graph(op)

  # Set the incoming gradient of TensorArray handles to None. The gradient
  # implementation currently assumes all resource tensors correspond to float32
  # ResourceVariables, which can lead to runtime shape errors when used with a
  # TensorArray. This is a workaround until TensorArrays are reimplemented with
  # TensorLists instead of resources.
  # Also set the incoming gradient of non-trainable inputs to None. It is
  # possible that we receive non-None gradients for non-trainable types in
  # nested while loops because we accumulate outputs of the inner while as
  # variant tensors which are trainable and hence receive zeros_like tensors in
  # the gradient pass. The non-trainable tensors then receive the popped zeros
  # tensor from this zeros variant. The gradient for the loop vars corresponding
  # to these tensors is None or zeros (this happens only if the loop var is
  # accumulated as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
  # output grads in _grad_fn.
  grads = [
      None if _is_tensor_array_handle(output) or
      not gradients_impl.IsTrainable(output) else grad
      for grad, output in zip(grads, op.outputs)
  ]

  # Ensure that all non-resource trainable outputs have incoming gradients.
  assert all(g is not None or o.dtype == dtypes.resource or
             not gradients_impl.IsTrainable(o)
             for o, g in zip(op.outputs, grads)
            ), "All trainable loop vars must receive incoming gradients."

  # If no gradient survived the filter there is nothing to differentiate.
  # Without this guard the `zip(*[...])` unpacking below raises ValueError,
  # because zipping zero triples yields nothing to unpack into three names.
  if all(g is None for g in grads):
    return [None] * len(grads)

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(
      *[(y, x, grad)
        for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
        if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, body_graph,
      util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  # The attr is only read in an XLA context, where it must be present; it
  # bounds the size of the accumulator lists created below.
  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  assert not _is_in_xla_context() or maximum_iterations is not None
  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=intermediate_tensor.shape,
        max_num_elements=maximum_iterations)

    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(
          tensor_list_ph, intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    # The gradient loop runs while the counter is below the iteration count.
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(outputs[0].op)
  _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]

  # Set None as the output gradient for tensors with None input gradient
  # e.g. TensorArray handles.
  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  index = 2
  none_padded_outputs = []
  for g in grads:
    if g is None:
      none_padded_outputs.append(None)
    else:
      none_padded_outputs.append(outputs[index])
      index += 1
  return none_padded_outputs
def wrap_function(fn, signature, name=None):
  """Traces the TF 1.x style function `fn` once and returns a graph function.

  `fn` is invoked a single time with symbolic arguments built from
  `signature`. The ops it creates are recorded in a graph, and the returned
  object can afterwards be called with tensors matching that signature.
  Variables created while tracing `fn` are owned by the returned object.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  `tf.compat.v1.wrap_function` and `tf.function` both produce a callable
  TensorFlow graph, but they differ in execution semantics. `tf.function` runs
  every stateful operation (e.g. `tf.print`) and orders operations to mimic
  eager execution. `wrap_function` behaves like `session.run` in TensorFlow
  1.x: an operation only runs if the function's outputs need it through a data
  or control dependency, and no ordering is imposed.

  In contrast to `tf.function`, the Python function is traced exactly once.
  Just like TF 1.x placeholders, shapes and dtypes therefore have to be given
  up front through `signature`. Tracing once is also what makes it possible to
  create variables and other state inside the function and have the wrapper
  object own them.

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the wrapped
      function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  variable_holder = VariableHolder(fn)
  # The graph name carries the optional user-supplied suffix.
  graph_name = (
      "wrapped_function" if name is None else "wrapped_function_" + name)
  traced_graph = func_graph.func_graph_from_py_func(
      graph_name,
      variable_holder,
      args=None,
      kwargs=None,
      signature=signature,
      add_control_dependencies=False,
      collections={})
  return WrappedFunction(
      traced_graph, variable_holder=variable_holder, signature=signature)
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds a gradient body graph for the forward body and emits a second While
  op that runs it. If the gradient body needs intermediates that the forward
  op does not output yet, the forward While op is rewritten in place to output
  them.

  Args:
    op: The forward While operation as seen by the gradient machinery.
    *grads: Incoming gradients for the outputs of the While op. Entries may be
      None.

  Returns:
    A list with one entry per (filtered) incoming gradient. An entry is None
    when the corresponding incoming gradient was None or belonged to a
    non-trainable output; otherwise it is the matching output of the gradient
    While op.
  """
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  # Remembered so that outputs added by a rewrite below can be identified.
  orig_num_params = len(body_graph.outputs)

  # The attr is only read in an XLA context, where it must be present.
  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  assert not _is_in_xla_context() or maximum_iterations is not None
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)

  # Set the incoming gradient of non-trainable inputs to None. It is possible
  # that we receive non-None gradients for non-trainable types in nested while
  # loops because we accumulate outputs of the inner while as variant tensors
  # which are trainable and hence receive zeros_like tensors in the gradient
  # pass. The non-trainable tensors then receive the popped zeros tensor from
  # this zeros variant. The gradient for the loop vars corresponding to these
  # tensors is None or zeros (this happens only if the loop var is accumulated
  # as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
  # output grads in _grad_fn.
  grads = [
      None if not _is_trainable(output) else grad
      for grad, output in zip(grads, body_graph.outputs)
  ]

  # If no gradient survived the filter there is nothing to differentiate.
  # Without this guard the `zip(*[...])` unpacking below raises ValueError,
  # because zipping zero triples yields nothing to unpack into three names.
  # Returning here also leaves the forward While op untouched.
  if all(g is None for g in grads):
    return [None] * len(grads)

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(
      *[(y, x, grad)
        for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
        if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    # Point the forward op at the renamed functions and widen its signature by
    # the accumulator inputs/outputs.
    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  def grad_cond(counter, max_iters, *unused_args):
    # The gradient loop runs while the counter is below the iteration count.
    return counter < max_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % while_op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(outputs[0].op)
  _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]

  # Set None as the output gradient for tensors with None input gradient.
  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  index = 2
  none_padded_outputs = []
  for g in grads:
    if g is None:
      none_padded_outputs.append(None)
    else:
      none_padded_outputs.append(outputs[index])
      index += 1
  return none_padded_outputs
def _create_grad_func(func_graph, grads, name):
  """Traces `_grad_fn` for `func_graph` and returns the resulting FuncGraph.

  Args:
    func_graph: The forward graph handed to `_grad_fn`.
    grads: The incoming gradients handed to `_grad_fn`.
    name: Name used both for the traced function and for its graph.

  Returns:
    The FuncGraph produced by tracing `_grad_fn(func_graph, grads)`.
  """

  def _compute_grads():
    # Zero-argument thunk: the tracer is given no args, so the inputs are
    # bound through the closure.
    return _grad_fn(func_graph, grads)

  grad_graph = util.CondBranchFuncGraph(name, read_only_collections=False)
  return func_graph_module.func_graph_from_py_func(
      name, _compute_grads, [], {}, func_graph=grad_graph)
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop.

  Builds a gradient body graph for the forward body, accumulates the
  intermediates of that gradient body in TensorLists, and emits a second While
  op that runs the gradient body.

  Args:
    op: The forward While operation.
    *grads: Incoming gradients, one per output of `op`. Entries may be None.

  Returns:
    A list with one entry per incoming gradient. An entry is None when the
    corresponding incoming gradient was None or belonged to a resource-typed
    output; otherwise it is the matching output of the gradient While op.
  """
  body_graph = _get_body_graph(op)

  # Set the incoming gradient of TensorArray handle to None.
  # TODO(b/118164915): We need a way of distinguishing between TensorArray
  # resource handles and ResourceVariables and set the default gradient of only
  # the TensorArray handle to None.
  grads = [
      None if output.dtype == dtypes.resource else g
      for g, output in zip(grads, op.outputs)
  ]

  # Ensure that all non-resource trainable outputs have incoming gradients.
  assert all(g is not None or o.dtype == dtypes.resource or
             not gradients_impl.IsTrainable(o)
             for o, g in zip(op.outputs, grads)
            ), "All trainable loop vars must receive incoming gradients."

  # If no gradient is left there is nothing to differentiate. Without this
  # guard the `zip(*[...])` unpacking below raises ValueError, because zipping
  # zero triples yields nothing to unpack into three names.
  if all(g is None for g in grads):
    return [None] * len(grads)

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(
      *[(y, x, grad)
        for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
        if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, body_graph,
      util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(
            intermediate_tensor.shape))

    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(
          tensor_list_ph, intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    # The gradient loop runs while the counter is below the iteration count.
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # Set None as the output gradient for tensors with None input gradient
  # e.g. TensorArray handles.
  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  index = 2
  none_padded_outputs = []
  for g in grads:
    if g is None:
      none_padded_outputs.append(None)
    else:
      none_padded_outputs.append(outputs[index])
      index += 1
  return none_padded_outputs
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op.

  Args:
    pred: The predicate selecting the branch. Must not be a Python bool; it is
      converted to a tensor.
    true_fn: Zero-argument callable traced into the "true" branch graph.
    false_fn: Zero-argument callable traced into the "false" branch graph.
    name: Name scope for the created ops. Falls back to "cond" when falsy.

  Returns:
    Identity ops on the outputs of the If op, packed into the structure of the
    true branch's outputs. The extra intermediate outputs are not returned.

  Raises:
    TypeError: If `pred` is a Python bool.
  """
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  name = name or "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Defuns add automatic control dependencies while v1 graphs do not; the
    # branch functions follow whatever the enclosing context does.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    def _trace_branch(branch_name, branch_fn):
      # Both branches are traced with the same settings; only the name and
      # the Python callable differ.
      return func_graph_module.func_graph_from_py_func(
          branch_name,
          branch_fn,
          [],
          {},
          func_graph=util.CondBranchFuncGraph(
              branch_name, read_only_collections=False),
          add_control_dependencies=add_control_dependencies,
          op_return_value=pred)

    true_graph = _trace_branch(true_name, true_fn)
    false_graph = _trace_branch(false_name, false_fn)
    _check_same_outputs(true_graph, false_graph)

    # Give both branch graphs the same inputs. This mutates true_graph and
    # false_graph.
    cond_inputs = _make_inputs_match(
        true_graph, false_graph,
        true_graph.external_captures, false_graph.external_captures)

    # Every intermediate tensor is exported as a function output so that the
    # gradient computation can use it.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Only this many leading outputs are handed back to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Pad so that both branches export the same number/type of extras.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)
    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Build the If op.
    output_dtypes = [t.dtype for t in true_graph.outputs]
    then_branch = util.create_new_tf_function(true_graph)
    else_branch = util.create_new_tf_function(false_graph)
    output_shapes = _get_output_shapes(true_graph.outputs,
                                       false_graph.outputs)
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs,
        output_dtypes,
        then_branch,
        else_branch,
        output_shapes=output_shapes,
        name=scope)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    util.maybe_set_lowering_attr(tensors[0].op)

    # Hand back identities instead of the raw If outputs so that pruning works
    # when the result of cond() is fetched: lowering turns the If outputs into
    # IdentityN outputs, and fetching those would run every op of the taken
    # branch (IdentityN consumes all merge ops). After lowering, each identity
    # below depends only on its own merge op.
    # TODO(b/79984175): this doesn't have to be a tuple once we covert to the
    # correct output structure
    tensors = tuple(map(array_ops.identity, tensors))

    return func_graph_module.pack_sequence_as(
        true_graph.structured_outputs, tensors[:num_cond_outputs])
def wrap_function(fn, signature, name=None):
  """Turns the TF 1.x function `fn` into a callable graph function.

  The Python function `fn` is called exactly once, with symbolic arguments
  described by `signature`; the trace is turned into a graph function that can
  be called with tensors matching the signature. Any variables `fn` creates
  during that trace belong to the object returned from `wrap_function`.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. `tf.function` runs all stateful operations (e.g.
  `tf.print`) and sequences operations so that the result matches eager
  execution. `wrap_function` is closer to `session.run` in TensorFlow 1.x: it
  runs only the operations required to compute the function's outputs, through
  either a data dependency or a control dependency, and it does not sequence
  operations.

  Because the Python function is traced only once (unlike with `tf.function`),
  shapes and dtypes must be supplied through `signature`, as with placeholders
  in TF 1.x. For the same reason, variables and state may be created inside the
  function and are owned by the function wrapper object.

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the wrapped
      function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  if name is None:
    graph_name = "wrapped_function"
  else:
    graph_name = "wrapped_function_" + name

  # Tracing options: arguments come from `signature` only, no automatic
  # control dependencies, and a fresh (empty) collections dict.
  trace_options = dict(
      args=None,
      kwargs=None,
      signature=signature,
      add_control_dependencies=False,
      collections={})
  traced_graph = func_graph.func_graph_from_py_func(
      graph_name, holder, **trace_options)
  return WrappedFunction(
      traced_graph, variable_holder=holder, signature=signature)