def _create_definition_if_needed(self):
  """Builds `self._definition` on first use; later calls return immediately.

  `self._func` is called once inside a temporary `_ExperimentalFuncGraph`,
  with one placeholder per `(name, type)` pair in `self._args`. The traced
  graph is converted to a function definition, `self._extra_kwargs` are
  copied onto it as attrs, and a hash string is computed from its signature
  and nodes. If `self._func_name` is unset, it is derived from the Python
  function's name and that hash string.

  Sets `_extra_inputs`, `_sub_functions`, `_definition`, `_hash_str` and,
  when it was empty, `_func_name`.

  Raises:
    ValueError: If any value returned by `self._func` is None.
  """
  if self._definition is not None:
    return

  # Trace the wrapped function in a scratch graph.
  func_graph = _ExperimentalFuncGraph(capture_by_value=self._capture_by_value)
  with func_graph.as_default():
    # One placeholder per declared argument, in `self._args` order.
    inputs = [
        array_ops.placeholder(arg_type, name=arg_name)
        for arg_name, arg_type in self._args
    ]
    # Run the function with the scratch graph's variable getter installed.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      outputs = self._func(*inputs)
    # A lone return value is normalized to a one-element tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)
    if any(out is None for out in outputs):
      raise ValueError("Function can not return None.")
    # Every output must be a Tensor.
    outputs = [ops.convert_to_tensor(out) for out in outputs]

  # Record what the scratch graph collected while tracing; the extra args
  # are appended after the declared inputs.
  self._extra_inputs = func_graph.extra_inputs
  inputs.extend(func_graph.extra_args)
  self._sub_functions = func_graph._functions

  # Turn the traced graph into a FunctionDef.
  traced_ops = func_graph.get_operations()
  self._definition = function._graph_to_function_def(
      func_graph, traced_ops, inputs, outputs, out_names=self._out_names)

  # Extra kwargs become attrs on the function def.
  if self._func_name:
    attr_owner_name = self._func_name
  else:
    attr_owner_name = function._get_func_name(self._func)
  extra_attrs = function._parse_kwargs_as_attrs(
      attr_owner_name, **self._extra_kwargs)
  for attr_key in extra_attrs:
    self._definition.attr[attr_key].CopyFrom(extra_attrs[attr_key])

  # Hash the definition and its dependencies.
  signature = self._definition.signature
  self._hash_str = self._create_hash_str(
      signature.input_arg, signature.output_arg, self._definition.node_def)

  # Choose the final name. Without an explicit one, combine the Python
  # function's name with the hash so the result is deterministic and almost
  # certainly unique.
  if not self._func_name:
    name_parts = [function._get_func_name(self._func), self._hash_str]
    self._func_name = "_".join(name_parts)
  self._definition.signature.name = self._func_name
  if self._func.__doc__:
    self._definition.signature.description = self._func.__doc__
def _create_definition_if_needed(self):
  """Creates and caches the function definition unless it already exists.

  Traces `self._func` in a fresh `_ExperimentalFuncGraph` using a placeholder
  for each `(name, type)` entry of `self._args`, converts the result into a
  function definition stored on `self._definition`, attaches
  `self._extra_kwargs` as attrs, computes `self._hash_str`, and settles the
  function name (generated from the Python function's name and the hash
  string when `self._func_name` is empty).

  Raises:
    ValueError: If `self._func` returns None, alone or as one of its outputs.
  """
  if self._definition is not None:
    # Already built; nothing to do.
    return

  scratch = _ExperimentalFuncGraph(capture_by_value=self._capture_by_value)
  with scratch.as_default():
    # Placeholders standing in for the function's declared arguments.
    arg_tensors = [
        array_ops.placeholder(dtype, name=name) for name, dtype in self._args
    ]
    # Invoke the function with variable lookups sent to `scratch.getvar`.
    with vs.variable_scope("", custom_getter=scratch.getvar):
      results = self._func(*arg_tensors)
    # Wrap a single returned value so the rest can treat it as a sequence.
    if not isinstance(results, (list, tuple)):
      results = (results,)
    if any(value is None for value in results):
      raise ValueError("Function can not return None.")
    # Make sure each result is a Tensor.
    result_tensors = [ops.convert_to_tensor(value) for value in results]

  # Keep what the scratch graph gathered; its extra args follow the
  # declared arguments in the input list.
  self._extra_inputs = scratch.extra_inputs
  arg_tensors.extend(scratch.extra_args)
  self._sub_functions = scratch._functions

  # Build the FunctionDef.
  self._definition = function._graph_to_function_def(
      scratch,
      scratch.get_operations(),
      arg_tensors,
      result_tensors,
      out_names=self._out_names)

  # Extra kwargs are copied onto the function def as attrs.
  name_for_attrs = self._func_name or function._get_func_name(self._func)
  parsed_attrs = function._parse_kwargs_as_attrs(
      name_for_attrs, **self._extra_kwargs)
  for key in parsed_attrs:
    self._definition.attr[key].CopyFrom(parsed_attrs[key])

  # Hash the definition and its dependencies.
  self._hash_str = self._create_hash_str(
      self._definition.signature.input_arg,
      self._definition.signature.output_arg,
      self._definition.node_def)

  # Decide the function name last. A generated name is the Python
  # function's name joined to the hash: deterministic and almost certainly
  # unique.
  if not self._func_name:
    python_name = function._get_func_name(self._func)
    self._func_name = "_".join([python_name, self._hash_str])
  self._definition.signature.name = self._func_name

  doc = self._func.__doc__
  if doc:
    self._definition.signature.description = doc
def testTwoInputsSameOp(self):
  """Checks that three inputs taken from one svd op give three input args."""
  graph = ops.Graph()
  with graph.as_default():
    matrix = array_ops.placeholder(dtypes.float32)
    # All three function inputs are outputs of this single op.
    s, u, v = linalg_ops.svd(matrix)
    s_sum, u_sum, v_sum = [math_ops.reduce_sum(t) for t in (s, u, v)]
    total = s_sum + u_sum + v_sum
    # Skip the placeholder; only the remaining ops go into the function.
    body_ops = graph.get_operations()[1:]
    fdef = function._graph_to_function_def(graph, body_ops, [s, u, v], [total])
  self.assertEqual(len(fdef.signature.input_arg), 3)
def testBasic(self):
  """Registers a hand-built function "foo" and calls it twice from a graph.

  foo(a:float, b:float, c:float) -> u:float, v:float, w:float
    u = matmul(a, b) + c
    v = u^2
    w = u + v
  """
  graph = tf.Graph()

  # The body of foo is built in its own graph.
  foo_graph = tf.Graph()
  with foo_graph.as_default():
    a = tf.placeholder(tf.float32, name="a")
    b = tf.placeholder(tf.float32, name="b")
    c = tf.placeholder(tf.float32, name="c")
    product = tf.matmul(a, b)
    u = tf.add(product, c, name="u")
    v = tf.square(u, name="v")
    w = tf.add_n([u, v], name="w")
  # NOTE(review): the second argument here is the string "foo", while other
  # callers in view pass a list of operations -- confirm this matches the
  # `_graph_to_function_def` signature in use.
  fdef = function._graph_to_function_def(foo_graph, "foo", [a, b, c], [u, v, w])

  class FooFunction(function._DefinedFunction):
    """Carries a ready-made definition; the base __init__ is not called."""

    def __init__(self, definition):
      self._func_name = "foo"
      self._definition = definition
      self._sub_functions = collections.OrderedDict()
      self._grad_func = None
      self._python_grad_func = None
      self._hash = hash(definition.SerializeToString())

  graph._add_function(FooFunction(fdef))

  # Compute 2 * 3 + 4 and its square.
  with graph.as_default(), tf.Session() as sess:
    two = tf.constant(self._mat(2.0), name="two")
    three = tf.constant(self._mat(3.0), name="three")
    four = tf.constant(self._mat(4.0), name="four")
    # TODO(zhifengc): w/ @decorator sugar, we will just do:
    #   y, s, t = foo_func(two, three, four)

    # The graph contains two ops, each of which calls foo.
    out_types = [tf.float32] * 3
    first_call = graph.create_op(
        "foo", [two, three, four], out_types, compute_shapes=False)
    u0, v0, w0 = first_call.outputs
    second_call = graph.create_op(
        "foo", [four, two, three], out_types, compute_shapes=False)
    u1, v1, w1 = second_call.outputs

    # Check some properties of the graph def.
    gdef = graph.as_graph_def()
    # Three constants plus the two foo calls.
    self.assertEqual(len(gdef.node), 5)
    # Exactly one function is defined in the library.
    self.assertEqual(len(gdef.library.function), 1)

    fetches = [u0, v0, w0, u1, v1, w1]
    for _ in xrange(10):
      # Run the graph, which is basically two function calls.
      got_u0, got_v0, got_w0, got_u1, got_v1, got_w1 = sess.run(fetches)
      checks = (
          (got_u0, 10.0),  # 2 * 3 + 4 = 10
          (got_v0, 100.0),  # 10^2 = 100
          (got_w0, 110.0),  # 100 + 10 = 110
          (got_u1, 11.0),  # 4 * 2 + 3 = 11
          (got_v1, 121.0),  # 11^2 = 121
          (got_w1, 132.0),  # 11 + 121 = 132
      )
      for actual, scalar in checks:
        self.assertAllEqual(actual, self._mat(scalar))
def testBasic(self):
  """Adds a manually constructed function "foo" to a graph and runs it.

  The function under test is:
    foo(a:float, b:float, c:float) -> u:float, v:float, w:float
      u = matmul(a, b) + c
      v = u^2
      w = u + v

  Two ops calling foo are created with different argument orders, and their
  outputs are checked over ten runs of the session.
  """
  g = tf.Graph()

  # Build foo's body in a separate graph.
  body = tf.Graph()
  with body.as_default():
    a = tf.placeholder(tf.float32, name="a")
    b = tf.placeholder(tf.float32, name="b")
    c = tf.placeholder(tf.float32, name="c")
    u = tf.add(tf.matmul(a, b), c, name="u")
    v = tf.square(u, name="v")
    w = tf.add_n([u, v], name="w")
  fn_inputs = [a, b, c]
  fn_outputs = [u, v, w]
  # NOTE(review): "foo" (a string) is passed as the second argument; other
  # call sites in view pass a list of operations there -- verify against the
  # `_graph_to_function_def` signature this test targets.
  fdef = function._graph_to_function_def(body, "foo", fn_inputs, fn_outputs)

  class PrebuiltFunction(function._DefinedFunction):
    """Holds an existing definition; does not call the base __init__."""

    def __init__(self, fdef):
      self._definition = fdef
      self._func_name = "foo"
      self._hash = hash(fdef.SerializeToString())
      self._sub_functions = collections.OrderedDict()
      self._grad_func = None
      self._python_grad_func = None

  prebuilt = PrebuiltFunction(fdef)
  g._add_function(prebuilt)

  # Compute 2 * 3 + 4 and its square.
  with g.as_default(), tf.Session() as sess:
    two = tf.constant(self._mat(2.0), name="two")
    three = tf.constant(self._mat(3.0), name="three")
    four = tf.constant(self._mat(4.0), name="four")
    # TODO(zhifengc): w/ @decorator sugar, we will just do:
    #   y, s, t = foo_func(two, three, four)

    # The graph contains two ops, each of which calls foo.
    float3 = [tf.float32, tf.float32, tf.float32]
    u0, v0, w0 = g.create_op(
        "foo", [two, three, four], float3, compute_shapes=False).outputs
    u1, v1, w1 = g.create_op(
        "foo", [four, two, three], float3, compute_shapes=False).outputs

    # Check some properties of the graph def.
    graph_def = g.as_graph_def()
    node_count = len(graph_def.node)
    self.assertEqual(node_count, 5)  # 3 constants + 2 calls to foo.
    function_count = len(graph_def.library.function)
    self.assertEqual(function_count, 1)  # Only foo is defined.

    for _ in xrange(10):
      # Run the graph, which is basically two function calls.
      answers = sess.run([u0, v0, w0, u1, v1, w1])
      ans_u0, ans_v0, ans_w0, ans_u1, ans_v1, ans_w1 = answers
      # First call: foo(two, three, four).
      self.assertAllEqual(ans_u0, self._mat(10.0))  # 2 * 3 + 4 = 10
      self.assertAllEqual(ans_v0, self._mat(100.0))  # 10^2 = 100
      self.assertAllEqual(ans_w0, self._mat(110.0))  # 100 + 10 = 110
      # Second call: foo(four, two, three).
      self.assertAllEqual(ans_u1, self._mat(11.0))  # 4 * 2 + 3 = 11
      self.assertAllEqual(ans_v1, self._mat(121.0))  # 11^2 = 121
      self.assertAllEqual(ans_w1, self._mat(132.0))  # 11 + 121 = 132