def _compute_backprop(self):
  """Computes the backprop function object for this function.

  Marks the function as having backprop, creates one gradient placeholder per
  non-None return value, differentiates those returns with respect to the
  input placeholders, and registers a forward FunctionDef (returns followed
  by captured tensors) and a backward FunctionDef. The backward function is
  stored on ``self._backward_function``.
  """

  def _by_name(obj):
    return obj.name

  self._has_backprop = True
  with self._graph.as_default(), context.graph_mode():
    capturing_ctx = _CapturingContext()
    with capturing_ctx:
      # Only non-None return values take part in differentiation.
      live_returns = [ret for ret in self._returns if ret is not None]
      self._out_grad_placeholders = [
          graph_placeholder(ret.dtype, ret.shape) for ret in live_returns
      ]
      grads_wrt_inputs = gradients_impl.gradients(
          live_returns,
          self._input_placeholders,
          grad_ys=self._out_grad_placeholders)
      grad_shapes = [g.shape for g in grads_wrt_inputs if g is not None]

  # Captured tensors are ordered by name so the signatures are deterministic.
  captured = sorted(capturing_ctx.captured_tensors, key=_by_name)
  forward_fdef = graph_to_function_def.graph_to_function_def(
      self._graph, self._ops, self._input_placeholders,
      live_returns + captured)
  self._forward_fdef = _DefinedFunction(forward_fdef)
  _register_with_name(_forward_name(self._func_name), forward_fdef)

  # None gradients are dropped from the backward function's outputs.
  nonnull_grads = [g for g in grads_wrt_inputs if g is not None]
  backward_inputs = self._out_grad_placeholders + captured
  backward_ops = [ph.op for ph in self._out_grad_placeholders]
  backward_ops.extend(sorted(capturing_ctx.known_ops, key=_by_name))
  backward_fdef = graph_to_function_def.graph_to_function_def(
      self._graph, backward_ops, backward_inputs, nonnull_grads)
  _register_with_name(_backward_name(self._func_name), backward_fdef)
  self._backward_function = _GraphModeFunction(
      backward_inputs, [], backward_fdef, self._graph,
      capturing_ctx.known_ops, grads_wrt_inputs,
      _map_sequence_obj_to_idx(nonnull_grads), grad_shapes)
def _compute_backprop(self):
  """Computes the backprop function object for this function.

  Marks the function as having backprop, creates one gradient placeholder per
  non-None return value, differentiates those returns with respect to the
  input placeholders, and registers a forward FunctionDef (returns followed
  by captured tensors) and a backward FunctionDef. The backward function is
  stored on ``self._backward_function``.
  """

  def _by_name(obj):
    return obj.name

  self._has_backprop = True
  with self._graph.as_default(), context.graph_mode():
    capturing_ctx = _CapturingContext()
    with capturing_ctx:
      # Only non-None return values take part in differentiation.
      live_returns = [ret for ret in self._returns if ret is not None]
      self._out_grad_placeholders = [
          graph_placeholder(ret.dtype, ret.shape) for ret in live_returns
      ]
      grads_wrt_inputs = gradients_impl.gradients(
          live_returns,
          self._input_placeholders,
          grad_ys=self._out_grad_placeholders)
      grad_shapes = [g.shape for g in grads_wrt_inputs if g is not None]

  # Captured tensors are ordered by name so the signatures are deterministic.
  captured = sorted(capturing_ctx.captured_tensors, key=_by_name)
  forward_fdef = graph_to_function_def.graph_to_function_def(
      self._graph, self._ops, self._input_placeholders,
      live_returns + captured)
  self._forward_fdef = _DefinedFunction(forward_fdef)
  _register_with_name(_forward_name(self._func_name), forward_fdef)

  # None gradients are dropped from the backward function's outputs.
  nonnull_grads = [g for g in grads_wrt_inputs if g is not None]
  backward_inputs = self._out_grad_placeholders + captured
  backward_ops = [ph.op for ph in self._out_grad_placeholders]
  backward_ops.extend(sorted(capturing_ctx.known_ops, key=_by_name))
  backward_fdef = graph_to_function_def.graph_to_function_def(
      self._graph, backward_ops, backward_inputs, nonnull_grads)
  _register_with_name(_backward_name(self._func_name), backward_fdef)
  self._backward_function = _GraphModeFunction(
      backward_inputs, [], backward_fdef, self._graph,
      capturing_ctx.known_ops, grads_wrt_inputs,
      _map_sequence_obj_to_idx(nonnull_grads), grad_shapes)
def _create_definition_if_needed(self):
  """Creates the function definition if it's not created yet.

  Traces ``self._func`` inside a fresh ``_FuncGraph``, with one placeholder
  per ``(name, dtype)`` pair in ``self._args``. The returned values are
  converted to tensors and the graph is serialized into
  ``self._definition``. The definition then receives attrs parsed from
  ``self._extra_kwargs``, a hash, a name and (if ``self._func`` has a
  docstring) a description. Does nothing if a definition already exists.

  Raises:
    ValueError: if any value returned by ``self._func`` is None.
  """
  if self._definition is not None:
    return

  func_graph = _FuncGraph()
  with func_graph.as_default():
    # One placeholder per declared argument, created in declaration order.
    arg_placeholders = [
        array_ops.placeholder(arg_type, name=arg_name)
        for arg_name, arg_type in self._args
    ]
    # Trace the user function; variable lookups go through the func graph's
    # ``getvar`` custom getter.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      results = self._func(*arg_placeholders)
    # A single non-sequence return value is treated as a 1-tuple.
    if not isinstance(results, (list, tuple)):
      results = (results,)
    if any(result is None for result in results):
      raise ValueError("Function can not return None.")
    results = [ops.convert_to_tensor(result) for result in results]

  self._extra_inputs = func_graph.extra_inputs
  arg_placeholders.extend(func_graph.extra_args)
  self._sub_functions = func_graph._functions  # pylint: disable=protected-access

  definition = graph_to_function_def.graph_to_function_def(
      func_graph, func_graph.get_operations(), arg_placeholders, results,
      out_names=self._out_names)
  self._definition = definition

  # Extra kwargs become attrs on the FunctionDef.
  attr_owner_name = self._func_name or _get_func_name(self._func)
  parsed_attrs = _parse_kwargs_as_attrs(attr_owner_name, **self._extra_kwargs)
  for attr_key in parsed_attrs:
    definition.attr[attr_key].CopyFrom(parsed_attrs[attr_key])

  signature = definition.signature
  self._hash_str = self._create_hash_str(
      signature.input_arg, signature.output_arg, definition.node_def)

  # Without an explicit name, derive a deterministic one from the Python
  # function's name plus the definition hash.
  if not self._func_name:
    self._func_name = "_".join((_get_func_name(self._func), self._hash_str))
  signature.name = self._func_name
  func_doc = self._func.__doc__
  if func_doc:
    signature.description = func_doc
def _create_definition_if_needed(self):
  """Creates the function definition if it's not created yet.

  Traces ``self._func`` inside a fresh ``_FuncGraph``, with one placeholder
  per ``(name, dtype)`` pair in ``self._args``. The returned values are
  converted to tensors and the graph is serialized into
  ``self._definition``. The definition then receives attrs parsed from
  ``self._extra_kwargs``, a hash, a name and (if ``self._func`` has a
  docstring) a description. Does nothing if a definition already exists.

  Raises:
    ValueError: if any value returned by ``self._func`` is None.
  """
  if self._definition is not None:
    return

  func_graph = _FuncGraph()
  with func_graph.as_default():
    # One placeholder per declared argument, created in declaration order.
    arg_placeholders = [
        array_ops.placeholder(arg_type, name=arg_name)
        for arg_name, arg_type in self._args
    ]
    # Trace the user function; variable lookups go through the func graph's
    # ``getvar`` custom getter.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      results = self._func(*arg_placeholders)
    # A single non-sequence return value is treated as a 1-tuple.
    if not isinstance(results, (list, tuple)):
      results = (results,)
    if any(result is None for result in results):
      raise ValueError("Function can not return None.")
    results = [ops.convert_to_tensor(result) for result in results]

  self._extra_inputs = func_graph.extra_inputs
  arg_placeholders.extend(func_graph.extra_args)
  self._sub_functions = func_graph._functions  # pylint: disable=protected-access

  definition = graph_to_function_def.graph_to_function_def(
      func_graph, func_graph.get_operations(), arg_placeholders, results,
      out_names=self._out_names)
  self._definition = definition

  # Extra kwargs become attrs on the FunctionDef.
  attr_owner_name = self._func_name or _get_func_name(self._func)
  parsed_attrs = _parse_kwargs_as_attrs(attr_owner_name, **self._extra_kwargs)
  for attr_key in parsed_attrs:
    definition.attr[attr_key].CopyFrom(parsed_attrs[attr_key])

  signature = definition.signature
  self._hash_str = self._create_hash_str(
      signature.input_arg, signature.output_arg, definition.node_def)

  # Without an explicit name, derive a deterministic one from the Python
  # function's name plus the definition hash.
  if not self._func_name:
    self._func_name = "_".join((_get_func_name(self._func), self._hash_str))
  signature.name = self._func_name
  func_doc = self._func.__doc__
  if func_doc:
    signature.description = func_doc
def _build_function_def(self):
  """Builds a FunctionDef with two Foo1 nodes and one ListOutput node.

  Asserts the node order and the node inputs of the generated FunctionDef,
  then returns it.
  """
  with ops.Graph().as_default() as graph:
    # Inputs:    x    y    z
    #            |\   |   /
    #            | \  |  /
    #            |  foo_1     list_output
    #            |   / \       /       \
    #            | d_1 e_1  a:1        a:0
    #            |  \   |   /           |
    #            |   \  |  /            |
    #            |    foo_2             |
    #            |     / \              |
    # Outputs:   x   d_2 e_2           a:0
    x, y, z = [
        array_ops.placeholder(dtype, name=ph_name)
        for ph_name, dtype in (("x", dtypes.float32), ("y", dtypes.int32),
                               ("z", dtypes.int32))
    ]
    foo_1_d, foo_1_e = test_ops._op_def_lib.apply_op(  # pylint: disable=protected-access
        "Foo1", name="foo_1", a=x, b=y, c=z)
    list_a0, list_a1 = test_ops.list_output(
        T=[dtypes.int32, dtypes.int32], name="list_output")
    foo_2_d, foo_2_e = test_ops.foo1(
        a=foo_1_d, b=foo_1_e, c=list_a1, name="foo_2")

  fdef = graph_to_function_def.graph_to_function_def(
      graph,
      graph.get_operations(),
      [x, y, z],
      [x, foo_2_d, foo_2_e, list_a0])

  # The body must hold exactly the two Foo1 nodes and the ListOutput node.
  nodes = fdef.node_def
  assert len(nodes) == 3
  assert nodes[0].op == "Foo1"
  assert nodes[0].input == ["x", "y", "z"]
  assert nodes[1].op == "ListOutput"
  assert not nodes[1].input
  assert nodes[2].op == "Foo1"
  assert nodes[2].input == ["foo_1:d:0", "foo_1:e:0", "list_output:a:1"]
  return fdef
def _defun_internal(name, func, args, kwds):
  """Defines and returns graph-mode version of func.

  Traces ``func(*inputs, **kwds)`` into a fresh graph and records any tensors
  captured from outside that graph. It then registers the inference
  FunctionDef plus every function defined inside the traced graph.

  Args:
    name: base name passed to ``_inference_name`` for registration.
    func: the Python callable to trace.
    args: positional arguments, turned into graph inputs by
      ``_get_defun_inputs``.
    kwds: keyword arguments forwarded unchanged to ``func``.

  Returns:
    A ``_GraphModeFunction`` wrapping the traced graph.
  """
  with context.graph_mode():
    func_graph = ops.Graph()
    # Copy the graph collections to ensure summaries and other things work.
    # This lets the function access (but not mutate) collections of the
    # containing graph, such as the global step and the summary writer
    # collections.
    outer_graph = ops.get_default_graph()
    for collection_name in outer_graph.collections:
      func_graph.get_collection_ref(collection_name)[:] = (
          outer_graph.get_collection(collection_name))
    with func_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      captures = {}
      with capture_tensors(captures):
        func_outputs = func(*func_inputs, **kwds)
      # Sorted keys give the extra inputs a deterministic order.
      capture_keys = sorted(captures)
      if capture_keys:
        extra_inputs, extra_placeholders = zip(
            *[captures[key] for key in capture_keys])
      else:
        extra_inputs, extra_placeholders = [], []
      outputs_list = nest.flatten(func_outputs)
      output_shapes = [out.shape for out in outputs_list if out is not None]

  tensor_inputs = [
      inp for inp in nest.flatten(func_inputs) if isinstance(inp, ops.Tensor)
  ]
  all_inputs = tensor_inputs + list(extra_placeholders)

  func_def_outputs = [
      ag_core.getval(out) for out in outputs_list if out is not None
  ]
  inference_function_def = graph_to_function_def.graph_to_function_def(
      func_graph, func_graph.get_operations(), all_inputs, func_def_outputs)

  # Register any other functions defined in the graph.
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # TODO(ashankar): What about the gradient registry?
  for nested_fn in func_graph._functions.values():  # pylint: disable=protected-access
    _register_with_name(nested_fn.name, nested_fn.definition)
  _register_with_name(_inference_name(name), inference_function_def)

  return _GraphModeFunction(
      all_inputs, extra_inputs, inference_function_def, func_graph,
      func_graph.get_operations(), func_outputs,
      _map_sequence_obj_to_idx(func_def_outputs), output_shapes)
def testTwoInputsSameOp(self):
  # The three SVD outputs come from a single op, yet each one must become a
  # separate input argument of the generated FunctionDef.
  graph = ops.Graph()
  with graph.as_default():
    matrix = array_ops.placeholder(dtypes.float32)
    s, u, v = linalg_ops.svd(matrix)
    partial_sums = [math_ops.reduce_sum(t) for t in (s, u, v)]
    result = partial_sums[0] + partial_sums[1] + partial_sums[2]
  # Every op except the leading placeholder goes into the function body.
  fdef = graph_to_function_def.graph_to_function_def(
      graph, graph.get_operations()[1:], [s, u, v], [result])
  self.assertEqual(len(fdef.signature.input_arg), 3)
def testTwoInputsSameOp(self):
  # The three SVD outputs come from a single op, yet each one must become a
  # separate input argument of the generated FunctionDef.
  graph = ops.Graph()
  with graph.as_default():
    matrix = array_ops.placeholder(dtypes.float32)
    s, u, v = linalg_ops.svd(matrix)
    partial_sums = [math_ops.reduce_sum(t) for t in (s, u, v)]
    result = partial_sums[0] + partial_sums[1] + partial_sums[2]
  # Every op except the leading placeholder goes into the function body.
  fdef = graph_to_function_def.graph_to_function_def(
      graph, graph.get_operations()[1:], [s, u, v], [result])
  self.assertEqual(len(fdef.signature.input_arg), 3)
def _defun_internal(name, func, args, kwds):
  """Defines and returns graph-mode version of func.

  Traces ``func(*inputs, **kwds)`` into a fresh graph and records any tensors
  captured from outside that graph. It then registers the inference
  FunctionDef plus every function defined inside the traced graph.

  Args:
    name: base name passed to ``_inference_name`` for registration.
    func: the Python callable to trace.
    args: positional arguments, turned into graph inputs by
      ``_get_defun_inputs``.
    kwds: keyword arguments forwarded unchanged to ``func``.

  Returns:
    A ``_GraphModeFunction`` wrapping the traced graph.
  """
  with context.graph_mode():
    func_graph = ops.Graph()
    # Copy the graph collections to ensure summaries and other things work.
    # This lets the function access (but not mutate) collections of the
    # containing graph, such as the global step and the summary writer
    # collections.
    outer_graph = ops.get_default_graph()
    for collection_name in outer_graph.collections:
      func_graph.get_collection_ref(collection_name)[:] = (
          outer_graph.get_collection(collection_name))
    with func_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      captures = {}
      with capture_tensors(captures):
        func_outputs = func(*func_inputs, **kwds)
      # Sorted keys give the extra inputs a deterministic order.
      capture_keys = sorted(captures)
      if capture_keys:
        extra_inputs, extra_placeholders = zip(
            *[captures[key] for key in capture_keys])
      else:
        extra_inputs, extra_placeholders = [], []
      outputs_list = nest.flatten(func_outputs)
      output_shapes = [out.shape for out in outputs_list if out is not None]

  tensor_inputs = [
      inp for inp in nest.flatten(func_inputs) if isinstance(inp, ops.Tensor)
  ]
  all_inputs = tensor_inputs + list(extra_placeholders)

  func_def_outputs = [out for out in outputs_list if out is not None]
  inference_function_def = graph_to_function_def.graph_to_function_def(
      func_graph, func_graph.get_operations(), all_inputs, func_def_outputs)

  # Register any other functions defined in the graph.
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # TODO(ashankar): What about the gradient registry?
  for nested_fn in func_graph._functions.values():  # pylint: disable=protected-access
    _register_with_name(nested_fn.name, nested_fn.definition)
  _register_with_name(_inference_name(name), inference_function_def)

  return _GraphModeFunction(
      all_inputs, extra_inputs, inference_function_def, func_graph,
      func_graph.get_operations(), func_outputs,
      _map_sequence_obj_to_idx(func_def_outputs), output_shapes)
def _build_function_def(self):
  """Returns a FunctionDef named "_whats_in_a_name" with two outputs.

  The inputs are float32 placeholders ``x`` and ``y``. The outputs are
  ``add_n([x**2, y**2])`` (op "sum_squares") and ``add_n([x**3, y**3])``
  (op "sum_cubes").
  """
  with ops.Graph().as_default() as graph:
    # Inputs.
    x = array_ops.placeholder(dtypes.float32, name="x")
    y = array_ops.placeholder(dtypes.float32, name="y")

    def power_sum(exponent, op_name):
      # Creates pow(x, exponent), then pow(y, exponent), then their add_n.
      return math_ops.add_n(
          [math_ops.pow(x, exponent), math_ops.pow(y, exponent)],
          name=op_name)

    # Outputs.
    sum_squares = power_sum(2, "sum_squares")
    sum_cubes = power_sum(3, "sum_cubes")

  fdef = graph_to_function_def.graph_to_function_def(
      graph, graph.get_operations(), [x, y], [sum_squares, sum_cubes])
  fdef.signature.name = "_whats_in_a_name"
  return fdef
def _build_function_def(self):
  """Returns a FunctionDef named "_whats_in_a_name" with two outputs.

  The inputs are float32 placeholders ``x`` and ``y``. The outputs are
  ``add_n([x**2, y**2])`` (op "sum_squares") and ``add_n([x**3, y**3])``
  (op "sum_cubes").
  """
  with ops.Graph().as_default() as graph:
    # Inputs.
    x = array_ops.placeholder(dtypes.float32, name="x")
    y = array_ops.placeholder(dtypes.float32, name="y")

    def power_sum(exponent, op_name):
      # Creates pow(x, exponent), then pow(y, exponent), then their add_n.
      return math_ops.add_n(
          [math_ops.pow(x, exponent), math_ops.pow(y, exponent)],
          name=op_name)

    # Outputs.
    sum_squares = power_sum(2, "sum_squares")
    sum_cubes = power_sum(3, "sum_cubes")

  fdef = graph_to_function_def.graph_to_function_def(
      graph, graph.get_operations(), [x, y], [sum_squares, sum_cubes])
  fdef.signature.name = "_whats_in_a_name"
  return fdef
def _build_function_def(self):
  """Builds a FunctionDef with two Foo1 nodes and one ListOutput node.

  Asserts the node order and the node inputs of the generated FunctionDef,
  then returns it.
  """
  with ops.Graph().as_default() as graph:
    # Inputs:    x    y    z
    #            |\   |   /
    #            | \  |  /
    #            |  foo_1     list_output
    #            |   / \       /       \
    #            | d_1 e_1  a:1        a:0
    #            |  \   |   /           |
    #            |   \  |  /            |
    #            |    foo_2             |
    #            |     / \              |
    # Outputs:   x   d_2 e_2           a:0
    x, y, z = [
        array_ops.placeholder(dtype, name=ph_name)
        for ph_name, dtype in (("x", dtypes.float32), ("y", dtypes.int32),
                               ("z", dtypes.int32))
    ]
    foo_1_d, foo_1_e = test_ops._op_def_lib.apply_op(  # pylint: disable=protected-access
        "Foo1", name="foo_1", a=x, b=y, c=z)
    list_a0, list_a1 = test_ops.list_output(
        T=[dtypes.int32, dtypes.int32], name="list_output")
    foo_2_d, foo_2_e = test_ops.foo1(
        a=foo_1_d, b=foo_1_e, c=list_a1, name="foo_2")

  fdef = graph_to_function_def.graph_to_function_def(
      graph,
      graph.get_operations(),
      [x, y, z],
      [x, foo_2_d, foo_2_e, list_a0])

  # The body must hold exactly the two Foo1 nodes and the ListOutput node.
  nodes = fdef.node_def
  assert len(nodes) == 3
  assert nodes[0].op == "Foo1"
  assert nodes[0].input == ["x", "y", "z"]
  assert nodes[1].op == "ListOutput"
  assert not nodes[1].input
  assert nodes[2].op == "Foo1"
  assert nodes[2].input == ["foo_1:d:0", "foo_1:e:0", "list_output:a:1"]
  return fdef
def make_function_def(graph, operations, inputs, outputs):
  """Makes function def where accesses to resources are serialized.

  Walks ``operations`` in order. Each op that consumes a resource-typed
  tensor gets a control input on the previous op (in that same order) that
  consumed the same resource tensor, so accesses to one resource are chained.

  Args:
    graph: the graph that owns ``operations``.
    operations: sequence of operations to serialize. Control inputs are added
      to these ops in place. The sequence is iterated twice (here and by
      ``graph_to_function_def``), so it must not be a one-shot iterator.
    inputs: tensors that become the function's inputs.
    outputs: tensors that become the function's outputs.

  Returns:
    The result of ``graph_to_function_def.graph_to_function_def(graph,
    operations, inputs, outputs)``.
  """
  last_op_using_resource_tensor = {}
  # TODO(apassos) probably control flow has to be handled delicately here as in
  # if a resource is accessed inside a control flow context we need the control
  # dependency to point to something outside the context which is guaranteed to
  # happen after the access.
  #
  # TODO(apassos) this should do some form of alias analysis as ops which
  # forward the resources such as Identity and Switch can cause serialization to
  # fail.
  for op in operations:
    for t in op.inputs:
      if t.dtype == dtypes.resource:
        previous_op = last_op_using_resource_tensor.get(t.name)
        # An op that takes the same resource tensor more than once must not
        # be made to depend on itself; that would create a cycle.
        if previous_op is not None and previous_op is not op:
          op._add_control_input(previous_op)  # pylint: disable=protected-access
        last_op_using_resource_tensor[t.name] = op
  return graph_to_function_def.graph_to_function_def(
      graph, operations, inputs, outputs)
def make_function_def(graph, operations, inputs, outputs):
  """Makes function def where accesses to resources are serialized.

  Walks ``operations`` in order. Each op that consumes a resource-typed
  tensor gets a control input on the previous op (in that same order) that
  consumed the same resource tensor, so accesses to one resource are chained.

  Args:
    graph: the graph that owns ``operations``.
    operations: sequence of operations to serialize. Control inputs are added
      to these ops in place. The sequence is iterated twice (here and by
      ``graph_to_function_def``), so it must not be a one-shot iterator.
    inputs: tensors that become the function's inputs.
    outputs: tensors that become the function's outputs.

  Returns:
    The result of ``graph_to_function_def.graph_to_function_def(graph,
    operations, inputs, outputs)``.
  """
  last_op_using_resource_tensor = {}
  # TODO(apassos) probably control flow has to be handled delicately here as in
  # if a resource is accessed inside a control flow context we need the control
  # dependency to point to something outside the context which is guaranteed to
  # happen after the access.
  #
  # TODO(apassos) this should do some form of alias analysis as ops which
  # forward the resources such as Identity and Switch can cause serialization to
  # fail.
  for op in operations:
    for t in op.inputs:
      if t.dtype == dtypes.resource:
        previous_op = last_op_using_resource_tensor.get(t.name)
        # An op that takes the same resource tensor more than once must not
        # be made to depend on itself; that would create a cycle.
        if previous_op is not None and previous_op is not op:
          op._add_control_input(previous_op)  # pylint: disable=protected-access
        last_op_using_resource_tensor[t.name] = op
  return graph_to_function_def.graph_to_function_def(
      graph, operations, inputs, outputs)
def _create_definition_if_needed_impl(self):
  """This is not what you want, see _create_definition_if_needed.

  Traces ``self._func`` into a function graph. The result is stored either
  as a Python FunctionDef (``self._definition``) or, when the traced graph
  has a C graph, as a C function (``self._c_func``). The method also sets
  ``self._extra_inputs``, ``self._sub_functions``, ``self._op_def``,
  ``self._func_name`` (when it was unset) and ``self._stateful_ops``.
  Returns immediately if either kind of definition already exists.
  """
  if self._definition is not None or self._c_func is not None:
    return

  # Share the parent graph's variable collections (by reference) with the
  # func graph so that name-based variable sharing, e.g. via
  # tf.make_template, works between the two graphs.
  shared_keys = list(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
  shared_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access
  parent_graph = ops.get_default_graph()
  collections_ref = {}
  for collection_key in shared_keys:
    collections_ref[collection_key] = parent_graph.get_collection_ref(
        collection_key)

  temp_graph = func_graph_from_py_func(
      self._func, self._arg_names, self._arg_types, self._func_name,
      self._capture_by_value, self._caller_device,
      collections_ref=collections_ref,
      allowlisted_stateful_ops=self._allowlisted_stateful_ops,
      capture_resource_var_by_value=self._capture_resource_var_by_value)

  self._extra_inputs = temp_graph.extra_inputs
  self._sub_functions = temp_graph._functions  # pylint: disable=protected-access

  # Only an auto-derived name gets the gradient function's name appended; an
  # explicit name is used verbatim.
  if self._func_name:
    base_func_name = self._func_name
  else:
    base_func_name = function_utils.get_func_name(self._func)
    if self._grad_func:
      base_func_name += "_%s" % self._grad_func.name
  # Extra kwargs are treated as attrs on the function def.
  kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

  if temp_graph._c_graph:  # pylint: disable=protected-access
    # C API path.
    output_names = (
        [compat.as_bytes(out_name) for out_name in self._out_names]
        if self._out_names else [])
    description = self._func.__doc__ or None
    # pylint: disable=protected-access
    tf_inputs = [t._as_tf_output() for t in temp_graph.inputs]
    tf_outputs = [t._as_tf_output() for t in temp_graph.outputs]
    c_func = c_api.TF_GraphToFunction_wrapper(
        temp_graph._c_graph,
        base_func_name,
        self._func_name is None,  # append_hash_to_fn_name
        None,  # opers
        tf_inputs,
        tf_outputs,
        output_names,
        [],  # control_outputs
        [],  # control_output_names
        None,  # opts
        description)
    # pylint: enable=protected-access
    self._c_func = c_api_util.ScopedTFFunction(c_func)
    self._set_c_attrs(kwargs_attr)

    # Cache _op_def, and _func_name when the caller did not supply one.
    self._op_def = self.definition.signature
    if not self._func_name:
      self._func_name = compat.as_str(self._op_def.name)
    else:
      assert self._func_name == self._op_def.name
  else:
    # Pure-Python path: serialize the graph into a FunctionDef proto.
    definition = graph_to_function_def.graph_to_function_def(
        temp_graph,
        temp_graph.get_operations(),
        temp_graph.inputs,
        temp_graph.outputs,
        out_names=self._out_names)
    self._definition = definition
    for attr_key in kwargs_attr:
      definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

    signature = definition.signature
    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        signature.input_arg, signature.output_arg, definition.node_def)
    # Without an explicit name, append the hash for a deterministic and
    # almost certainly unique name.
    if not self._func_name:
      self._func_name = "_".join((base_func_name, self._hash_str))
    signature.name = self._func_name
    if self._func.__doc__:
      signature.description = self._func.__doc__
    self._op_def = signature

  # pylint: disable=protected-access
  self._stateful_ops = [
      (graph_op.name, graph_op.type)
      for graph_op in temp_graph.get_operations()
      if graph_op._is_stateful
  ]
  # pylint: enable=protected-access
def _create_definition_if_needed_impl(self):
  """This is not what you want, see _create_definition_if_needed.

  Traces ``self._func`` into a function graph. The result is stored either
  as a Python FunctionDef (``self._definition``) or, when the traced graph
  has a C graph, as a C function (``self._c_func``). The method also sets
  ``self._extra_inputs``, ``self._sub_functions``, ``self._op_def``,
  ``self._func_name`` (when it was unset) and ``self._stateful_ops``.
  Returns immediately if either kind of definition already exists.
  """
  if self._definition is not None or self._c_func is not None:
    return

  temp_graph = func_graph_from_py_func(
      self._func, self._arg_names, self._arg_types, self._func_name,
      self._capture_by_value, self._caller_device,
      whitelisted_stateful_ops=self._whitelisted_stateful_ops)

  self._extra_inputs = temp_graph.extra_inputs
  self._sub_functions = temp_graph._functions  # pylint: disable=protected-access

  # Only an auto-derived name gets the gradient function's name appended; an
  # explicit name is used verbatim.
  if self._func_name:
    base_func_name = self._func_name
  else:
    base_func_name = function_utils.get_func_name(self._func)
    if self._grad_func:
      base_func_name += "_%s" % self._grad_func.name
  # Extra kwargs are treated as attrs on the function def.
  kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

  if temp_graph._c_graph:  # pylint: disable=protected-access
    # C API path.
    output_names = (
        [compat.as_bytes(out_name) for out_name in self._out_names]
        if self._out_names else [])
    description = self._func.__doc__ or None
    # pylint: disable=protected-access
    tf_inputs = [t._as_tf_output() for t in temp_graph.inputs]
    tf_outputs = [t._as_tf_output() for t in temp_graph.outputs]
    c_func = c_api.TF_GraphToFunction_wrapper(
        temp_graph._c_graph,
        base_func_name,
        self._func_name is None,  # append_hash_to_fn_name
        None,  # opers
        tf_inputs,
        tf_outputs,
        output_names,
        None,  # opts
        description)
    # pylint: enable=protected-access
    self._c_func = c_api_util.ScopedTFFunction(c_func)
    self._set_c_attrs(kwargs_attr)

    # Cache _op_def, and _func_name when the caller did not supply one.
    self._op_def = self.definition.signature
    if not self._func_name:
      self._func_name = compat.as_str(self._op_def.name)
    else:
      assert self._func_name == self._op_def.name
  else:
    # Pure-Python path: serialize the graph into a FunctionDef proto.
    definition = graph_to_function_def.graph_to_function_def(
        temp_graph,
        temp_graph.get_operations(),
        temp_graph.inputs,
        temp_graph.outputs,
        out_names=self._out_names)
    self._definition = definition
    for attr_key in kwargs_attr:
      definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

    signature = definition.signature
    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        signature.input_arg, signature.output_arg, definition.node_def)
    # Without an explicit name, append the hash for a deterministic and
    # almost certainly unique name.
    if not self._func_name:
      self._func_name = "_".join((base_func_name, self._hash_str))
    signature.name = self._func_name
    if self._func.__doc__:
      signature.description = self._func.__doc__
    self._op_def = signature

  self._stateful_ops = [
      (graph_op.name, graph_op.type)
      for graph_op in temp_graph.get_operations()
      if graph_op.op_def.is_stateful
  ]
def _graph_callable_internal(func, shape_and_dtypes):
  """Defines and returns a template version of func.

  Under the hood we make two function objects, each wrapping a different version
  of the graph-mode code. One version immediately runs variable initialization
  before making the variable's Tensors available for use, while the other
  version replaces the Variables with placeholders which become function
  arguments and get the current variable's value.

  Some limitations of this approach exist because it does not implement a
  graph-mode Variable class which has a convert_to_tensor(as_ref=True) method
  and an initialized_value method. This is fixable.

  Args:
    func: The tfe Python function to compile.
    shape_and_dtypes: A list of type ShapeAndDtype.

  Raises:
    TypeError: If the number of named arguments of `func` (as reported by
      `tf_inspect.getargspec`) differs from the number of inputs built from
      `shape_and_dtypes`.
    ValueError: If any one of func's outputs is not a Tensor.

  Returns:
    Callable graph object.
  """
  with context.graph_mode():
    # This graph will store both the initialization and the call version of the
    # wrapped function. It will later be used by the backprop code to build the
    # backprop graph, if necessary.
    tmp_graph = tf_ops.Graph()
    with tmp_graph.as_default():
      # Placeholders for the non-variable inputs.
      func_inputs = _get_graph_callable_inputs(shape_and_dtypes)
      func_num_args = len(tf_inspect.getargspec(func).args)
      if len(func_inputs) != func_num_args:
        raise TypeError("The number of arguments accepted by the decorated "
                        "function `%s` (%d) must match the number of "
                        "ShapeAndDtype objects passed to the graph_callable() "
                        "decorator (%d)." %
                        (func.__name__, func_num_args, len(func_inputs)))

      # First call the function to generate a graph which can initialize all
      # variables. As a side-effect this will populate the variable capturing
      # scope's view of which variables exist.
      variable_captures = _VariableCapturingScope()
      captures = {}
      with variable_captures.initializing_scope(), function.capture_tensors(
          captures):
        func_outputs = func(*func_inputs)
      outputs_list = nest.flatten(func_outputs)
      # Validate before reading `.shape`. Otherwise a non-Tensor output
      # without a `shape` attribute (e.g. a Python number) would fail with an
      # AttributeError instead of the documented ValueError.
      if not all(isinstance(x, tf_ops.Tensor) for x in outputs_list):
        raise ValueError("Found non-tensor output in %s" % str(outputs_list))
      # Every output is a Tensor at this point.
      output_shapes = [x.shape for x in outputs_list]
      initializing_operations = tmp_graph.get_operations()

      # Call the function again, now replacing usages of variables with
      # placeholders. This assumes the variable capturing scope created above
      # knows about all variables.
      with variable_captures.capturing_scope(), function.capture_tensors(
          captures):
        captured_outputs = func(*func_inputs)
      captured_outlist = nest.flatten(captured_outputs)
      capturing_operations = tmp_graph.get_operations()[
          len(initializing_operations):]

  sorted_variables = sorted(variable_captures.variables.values(),
                            key=lambda x: x.name)
  variable_placeholders = [x.placeholder for x in sorted_variables]
  ids = sorted(captures.keys())
  if ids:
    extra_inputs, extra_placeholders = zip(*[captures[x] for x in ids])
  else:
    extra_inputs = []
    extra_placeholders = []

  flat_inputs = [
      x for x in nest.flatten(func_inputs) if isinstance(x, tf_ops.Tensor)
  ]
  placeholder_inputs = flat_inputs + list(extra_placeholders)

  all_inputs = variable_placeholders + placeholder_inputs
  func_def_outputs = [
      x for x in outputs_list if isinstance(x, tf_ops.Tensor)
  ]
  initializer_function_def = graph_to_function_def.graph_to_function_def(
      tmp_graph, initializing_operations, placeholder_inputs, func_def_outputs)
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  # Also, what about the gradient registry of these functions? Those need to be
  # addressed as well.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    function._register_with_name(f.name, f.definition)  # pylint: disable=protected-access
  function._register_with_name(
      function._inference_name(func.__name__),  # pylint: disable=protected-access
      initializer_function_def)
  initializer_function = function._GraphModeFunction(  # pylint: disable=protected-access
      placeholder_inputs,
      extra_inputs,
      initializer_function_def,
      tmp_graph,
      initializing_operations,
      func_outputs,
      function._map_sequence_obj_to_idx(func_def_outputs),  # pylint: disable=protected-access
      output_shapes)

  capture_func_def_outputs = [
      x for x in captured_outlist if isinstance(x, tf_ops.Tensor)
  ]
  captured_function_def = graph_to_function_def.graph_to_function_def(
      tmp_graph, capturing_operations, all_inputs, capture_func_def_outputs)
  function._register_with_name(
      function._inference_name(func.__name__),  # pylint: disable=protected-access
      captured_function_def)
  captured_function = _FunctionObject(
      sorted_variables,
      all_inputs,
      extra_inputs,
      captured_function_def,
      tmp_graph,
      capturing_operations,
      captured_outputs,
      function._map_sequence_obj_to_idx(capture_func_def_outputs),  # pylint: disable=protected-access
      output_shapes)

  return _InitializingFunctionObject(captured_function, initializer_function)
def _create_definition_if_needed_impl(self):
  """This is not what you want, see _create_definition_if_needed.

  Traces ``self._func`` into a function graph. The result is stored either
  as a Python FunctionDef (``self._definition``) or, when the traced graph
  has a C graph, as a C function (``self._c_func``). The method also sets
  ``self._extra_inputs``, ``self._sub_functions``, ``self._op_def``,
  ``self._func_name`` (when it was unset) and ``self._stateful_ops``.
  Returns immediately if either kind of definition already exists.
  """
  if self._definition is not None or self._c_func is not None:
    return

  temp_graph = func_graph_from_py_func(
      self._func, self._arg_names, self._arg_types, self._func_name,
      self._capture_by_value, self._caller_device)

  self._extra_inputs = temp_graph.extra_inputs
  self._sub_functions = temp_graph._functions  # pylint: disable=protected-access

  # Only an auto-derived name gets the gradient function's name appended; an
  # explicit name is used verbatim.
  if self._func_name:
    base_func_name = self._func_name
  else:
    base_func_name = _get_func_name(self._func)
    if self._grad_func:
      base_func_name += "_%s" % self._grad_func.name
  # Extra kwargs are treated as attrs on the function def.
  kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

  if temp_graph._c_graph:  # pylint: disable=protected-access
    # C API path.
    output_names = (
        [compat.as_bytes(out_name) for out_name in self._out_names]
        if self._out_names else [])
    description = self._func.__doc__ or None
    # pylint: disable=protected-access
    tf_inputs = [t._as_tf_output() for t in temp_graph.inputs]
    tf_outputs = [t._as_tf_output() for t in temp_graph.outputs]
    c_func = c_api.TF_GraphToFunction_wrapper(
        temp_graph._c_graph,
        base_func_name,
        self._func_name is None,  # append_hash_to_fn_name
        None,  # opers
        tf_inputs,
        tf_outputs,
        output_names,
        None,  # opts
        description)
    # pylint: enable=protected-access
    self._c_func = c_api_util.ScopedTFFunction(c_func)
    self._set_c_attrs(kwargs_attr)

    # Cache _op_def, and _func_name when the caller did not supply one.
    self._op_def = self.definition.signature
    if not self._func_name:
      self._func_name = compat.as_str(self._op_def.name)
    else:
      assert self._func_name == self._op_def.name
  else:
    # Pure-Python path: serialize the graph into a FunctionDef proto.
    definition = graph_to_function_def.graph_to_function_def(
        temp_graph,
        temp_graph.get_operations(),
        temp_graph.inputs,
        temp_graph.outputs,
        out_names=self._out_names)
    self._definition = definition
    for attr_key in kwargs_attr:
      definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

    signature = definition.signature
    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        signature.input_arg, signature.output_arg, definition.node_def)
    # Without an explicit name, append the hash for a deterministic and
    # almost certainly unique name.
    if not self._func_name:
      self._func_name = "_".join((base_func_name, self._hash_str))
    signature.name = self._func_name
    if self._func.__doc__:
      signature.description = self._func.__doc__
    self._op_def = signature

  self._stateful_ops = [
      (graph_op.name, graph_op.type)
      for graph_op in temp_graph.get_operations()
      if graph_op.op_def.is_stateful
  ]
def _create_definition_if_needed_impl(self):
    """Builds and caches the definition of this function (no-op if cached).

    Not meant to be called directly; use `_create_definition_if_needed`.

    The variable collections of the current default (parent) graph are shared
    by reference with the function graph built from `self._func`, so that
    name-based variable sharing (e.g. via tf.make_template) works between the
    two graphs. If the function graph has no C graph, a `FunctionDef` is
    assembled in Python and stored in `self._definition`; otherwise the graph
    is converted through the C API and stored in `self._c_func`. Either way
    `self._op_def`, `self._func_name`, `self._extra_inputs`,
    `self._sub_functions` and `self._stateful_ops` are populated.
    """
    if not (self._definition is None and self._c_func is None):
        return  # Already built on an earlier call.

    # Collection keys to share between the parent graph and the function graph.
    shared_keys = list(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    shared_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    # Hand the *same* list objects to the function graph. A key missing from
    # the parent is created there first, so both graphs hold one shared list.
    parent_collections = ops.get_default_graph()._collections  # pylint: disable=protected-access
    collections_ref = {}
    for key in shared_keys:
        if key in parent_collections:
            shared_list = parent_collections[key]
        else:
            shared_list = parent_collections[key] = []
        collections_ref[key] = shared_list

    func_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        whitelisted_stateful_ops=self._whitelisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)
    self._extra_inputs = func_graph.extra_inputs
    self._sub_functions = func_graph._functions  # pylint: disable=protected-access

    # Name used for the attrs below and, when no explicit name was given, as
    # the prefix of the generated function name. An explicit name is used
    # verbatim; otherwise it is derived from the Python callable and, if a
    # gradient function is attached, suffixed with that function's name.
    name_prefix = self._func_name
    if not name_prefix:
        name_prefix = function_utils.get_func_name(self._func)
        if self._grad_func:
            name_prefix += "_%s" % self._grad_func.name
    # Extra kwargs are treated as attrs on the function def.
    attrs = _parse_kwargs_as_attrs(name_prefix, **self._extra_kwargs)

    if func_graph._c_graph:  # pylint: disable=protected-access
        # C API path: let TF_GraphToFunction build the function.
        encoded_out_names = (
            [compat.as_bytes(out_name) for out_name in self._out_names]
            if self._out_names else [])
        doc_text = self._func.__doc__ or None
        # pylint: disable=protected-access
        raw_c_func = c_api.TF_GraphToFunction_wrapper(
            func_graph._c_graph,
            name_prefix,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            [tensor._as_tf_output() for tensor in func_graph.inputs],
            [tensor._as_tf_output() for tensor in func_graph.outputs],
            encoded_out_names,
            [],  # control_outputs
            [],  # control_output_names
            None,  # opts
            doc_text)
        # pylint: enable=protected-access
        self._c_func = c_api_util.ScopedTFFunction(raw_c_func)
        self._set_c_attrs(attrs)

        # NOTE(review): `self.definition` is read only after the attrs are set;
        # presumably it is derived from `self._c_func` -- keep this order.
        self._op_def = self.definition.signature
        if not self._func_name:
            # No explicit name: adopt the one the C API produced.
            self._func_name = compat.as_str(self._op_def.name)
        else:
            # No hash was appended, so the requested name must have survived.
            assert self._func_name == self._op_def.name
    else:
        # Python path: assemble the FunctionDef from the graph's operations.
        fdef = graph_to_function_def.graph_to_function_def(
            func_graph,
            func_graph.get_operations(),
            func_graph.inputs,
            func_graph.outputs,
            out_names=self._out_names)
        self._definition = fdef
        for attr_name in attrs:
            fdef.attr[attr_name].CopyFrom(attrs[attr_name])

        # Hash the definition and its dependencies.
        self._hash_str = self._create_hash_str(
            fdef.signature.input_arg, fdef.signature.output_arg, fdef.node_def)

        # Without an explicit name, make up one that is almost certainly
        # unique but deterministic: "<prefix>_<hash>".
        if not self._func_name:
            self._func_name = "_".join((name_prefix, self._hash_str))

        signature = self._definition.signature
        signature.name = self._func_name
        if self._func.__doc__:
            signature.description = self._func.__doc__
        self._op_def = self._definition.signature

    # (name, type) of every op in the function graph whose OpDef is stateful.
    self._stateful_ops = [
        (op.name, op.type)
        for op in func_graph.get_operations()
        if op.op_def.is_stateful
    ]
def _create_definition_if_needed_impl(self):
    """Builds and caches the definition of this function (no-op if cached).

    Not meant to be called directly; use `_create_definition_if_needed`.

    Traces `self._func` inside a fresh `_FuncGraph`:
    - One placeholder is created per `(argname, argtype)` in `self._args`.
    - The function is called on those placeholders, and its outputs are
      normalized to a list of tensors.
    - The extra arguments collected by the graph are appended to the inputs.

    If the graph has no C graph, a `FunctionDef` is assembled in Python and
    stored in `self._definition`. Otherwise the graph is converted through
    the C API and stored in `self._c_func`. Either way `self._op_def` and
    `self._func_name` end up set.

    Raises:
      ValueError: if the traced function returns a list/tuple containing
        None.
    """
    if self._definition is not None or self._c_func is not None:
        return

    # Create the func_def object.
    temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
    with temp_graph.as_default():
        # Placeholders for the function_def, one per declared argument.
        inputs = [array_ops.placeholder(argtype, name=argname)
                  for argname, argtype in self._args]
        # Call func and gather the output tensors; variable lookups inside the
        # call go through `temp_graph.getvar` (the custom getter).
        with vs.variable_scope("", custom_getter=temp_graph.getvar):
            outputs = self._func(*inputs)

        # There is no way of distinguishing between a function not returning
        # anything and a function returning None in Python.
        # We need to allow the former and ideally want to forbid the latter as
        # it is most likely user error.
        # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
        # allow users to explicitly mark the function as not returning anything.
        # For now, we allow a single None return and interpret it as a function
        # with no output.
        if outputs is None:
            outputs = []
        else:
            # If func only returned one value, make it a tuple.
            if not isinstance(outputs, (list, tuple)):
                outputs = (outputs,)
            if any(output is None for output in outputs):
                raise ValueError("Function can not return None.")
        # Ensures each output is a Tensor (converted inside the function graph).
        outputs = [ops.convert_to_tensor(output) for output in outputs]

    self._extra_inputs = temp_graph.extra_inputs
    # Extra arguments collected while tracing become additional inputs.
    inputs.extend(temp_graph.extra_args)
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    base_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
        # Build the FunctionDef
        self._definition = graph_to_function_def.graph_to_function_def(
            temp_graph,
            temp_graph.get_operations(),
            inputs,
            outputs,
            out_names=self._out_names)

        for k in kwargs_attr:
            self._definition.attr[k].CopyFrom(kwargs_attr[k])

        # Hash the definition and its dependencies.
        self._hash_str = self._create_hash_str(
            self._definition.signature.input_arg,
            self._definition.signature.output_arg, self._definition.node_def)

        # Finally, we decide the function name to use.  If not specified,
        # make up something which is almost certainly unique (but deterministic).
        if not self._func_name:
            self._func_name = "_".join([base_func_name, self._hash_str])
        self._definition.signature.name = self._func_name
        if self._func.__doc__:
            self._definition.signature.description = self._func.__doc__

        self._op_def = self._definition.signature
    else:  # C API is enabled
        output_names = ([compat.as_bytes(x) for x in self._out_names]
                        if self._out_names else [])
        description = self._func.__doc__ or None
        # pylint: disable=protected-access
        with errors.raise_exception_on_not_ok_status() as status:
            self._c_func = c_api.TF_GraphToFunction_wrapper(
                temp_graph._c_graph,
                base_func_name,
                self._func_name is None,  # append_hash_to_fn_name
                None,  # opers
                [t._as_tf_output() for t in inputs],
                [t._as_tf_output() for t in outputs],
                output_names,
                None,  # opts
                description,
                status)
        # pylint: enable=protected-access
        self._set_c_attrs(kwargs_attr)

        # Set cached fields: _op_def and _func_name (if not already set).
        # NOTE(review): presumably `self.definition` is derived from
        # `self._c_func`, so it must be read after the attrs are set.
        self._op_def = self.definition.signature
        if self._func_name:
            # No hash was appended, so the requested name must have survived.
            assert self._func_name == self._op_def.name
        else:
            self._func_name = compat.as_str(self._op_def.name)