def testBasic(self):
    """Defines a three-output graph function `foo` and calls it twice.

    `foo(a, b, c)` returns `u = matmul(a, b) + c`, `v = u^2` and `w = u + v`.
    The function body is built in a separate graph, converted to a FunctionDef
    and registered in graph `g`'s function library. Two call ops with
    different argument orders are then added to `g`. The test checks the size
    of the resulting GraphDef and its library, and runs both calls ten times
    against expected values.

    NOTE(review): the arithmetic in the per-assert comments (e.g. 2 * 3 + 4 =
    10) assumes `self._mat(x)` (defined elsewhere in this test class) yields a
    1x1 matrix holding `x` -- confirm.
    """
    g = tf.Graph()

    # Define a function
    #   foo(a:float, b:float, c:float)->u:float,v:float,w:float
    #     u = matmul(a, b) + c
    #     v = u^2
    #     w = u + v
    # TODO(zhifengc): replace with a nicer @decorator sugar.
    foo = tf.Graph()
    with foo.as_default():
        a = tf.placeholder(tf.float32, name="a")
        b = tf.placeholder(tf.float32, name="b")
        c = tf.placeholder(tf.float32, name="c")
        u = tf.add(tf.matmul(a, b), c, name="u")
        v = tf.square(u, name="v")
        w = tf.add_n([u, v], name="w")
    fdef = function.graph_to_function_def(foo, "foo", [a, b, c], [u, v, w])
    # Private Graph API: registers `fdef` in g's function library so that
    # ops of type "foo" can be created in g below.
    g._add_function(fdef)

    # Compute u = x * y + z, its square, and their sum, for two argument orders.
    with g.as_default(), tf.Session() as sess:
        two = tf.constant(self._mat(2.0), name="two")
        three = tf.constant(self._mat(3.0), name="three")
        four = tf.constant(self._mat(4.0), name="four")
        # TODO(zhifengc): with @decorator sugar, we will just do:
        #   u, v, w = foo_func(two, three, four)
        # The graph contains two ops, each of which calls foo.
        # compute_shapes=False skips shape inference for the call ops;
        # presumably no shape function exists for the library function
        # "foo" -- confirm.
        u0, v0, w0 = g.create_op(
            "foo", [two, three, four],
            [tf.float32, tf.float32, tf.float32],
            compute_shapes=False).outputs
        u1, v1, w1 = g.create_op(
            "foo", [four, two, three],
            [tf.float32, tf.float32, tf.float32],
            compute_shapes=False).outputs

        # Check some properties of the graph def.
        gdef = g.as_graph_def()
        # 3 constants + 2 call ops; the placeholders live in `foo`, not `g`.
        self.assertEqual(len(gdef.node), 5)
        self.assertEqual(len(gdef.library.function), 1)  # 1 function defined.

        # Run repeatedly so each iteration re-executes both function calls.
        for _ in range(10):
            # Run the graph, which is basically two function calls.
            ans_u0, ans_v0, ans_w0, ans_u1, ans_v1, ans_w1 = sess.run(
                [u0, v0, w0, u1, v1, w1])
            self.assertAllEqual(ans_u0, self._mat(10.0))   # 2 * 3 + 4 = 10
            self.assertAllEqual(ans_v0, self._mat(100.0))  # 10^2 = 100
            self.assertAllEqual(ans_w0, self._mat(110.0))  # 10 + 100 = 110
            self.assertAllEqual(ans_u1, self._mat(11.0))   # 4 * 2 + 3 = 11
            self.assertAllEqual(ans_v1, self._mat(121.0))  # 11^2 = 121
            self.assertAllEqual(ans_w1, self._mat(132.0))  # 11 + 121 = 132
def testBasic(self):
    """Defines a three-output graph function `foo` and calls it twice.

    `foo(a, b, c)` returns `u = matmul(a, b) + c`, `v = u^2` and `w = u + v`.
    The function body is built in a separate graph, converted to a FunctionDef
    and registered in graph `g`'s function library. Two call ops with
    different argument orders are then added to `g`. The test checks the size
    of the resulting GraphDef and its library, and runs both calls ten times
    against expected values.

    NOTE(review): the arithmetic in the per-assert comments (e.g. 2 * 3 + 4 =
    10) assumes `self._mat(x)` (defined elsewhere in this test class) yields a
    1x1 matrix holding `x` -- confirm.
    """
    g = tf.Graph()

    # Define a function
    #   foo(a:float, b:float, c:float)->u:float,v:float,w:float
    #     u = matmul(a, b) + c
    #     v = u^2
    #     w = u + v
    # TODO(zhifengc): replace with a nicer @decorator sugar.
    foo = tf.Graph()
    with foo.as_default():
        a = tf.placeholder(tf.float32, name="a")
        b = tf.placeholder(tf.float32, name="b")
        c = tf.placeholder(tf.float32, name="c")
        u = tf.add(tf.matmul(a, b), c, name="u")
        v = tf.square(u, name="v")
        w = tf.add_n([u, v], name="w")
    fdef = function.graph_to_function_def(foo, "foo", [a, b, c], [u, v, w])
    # Private Graph API: registers `fdef` in g's function library so that
    # ops of type "foo" can be created in g below.
    g._add_function(fdef)

    # Compute u = x * y + z, its square, and their sum, for two argument orders.
    with g.as_default(), tf.Session() as sess:
        two = tf.constant(self._mat(2.0), name="two")
        three = tf.constant(self._mat(3.0), name="three")
        four = tf.constant(self._mat(4.0), name="four")
        # TODO(zhifengc): with @decorator sugar, we will just do:
        #   u, v, w = foo_func(two, three, four)
        # The graph contains two ops, each of which calls foo.
        # compute_shapes=False skips shape inference for the call ops;
        # presumably no shape function exists for the library function
        # "foo" -- confirm.
        u0, v0, w0 = g.create_op(
            "foo", [two, three, four],
            [tf.float32, tf.float32, tf.float32],
            compute_shapes=False).outputs
        u1, v1, w1 = g.create_op(
            "foo", [four, two, three],
            [tf.float32, tf.float32, tf.float32],
            compute_shapes=False).outputs

        # Check some properties of the graph def.
        gdef = g.as_graph_def()
        # 3 constants + 2 call ops; the placeholders live in `foo`, not `g`.
        self.assertEqual(len(gdef.node), 5)
        self.assertEqual(len(gdef.library.function), 1)  # 1 function defined.

        # Run repeatedly so each iteration re-executes both function calls.
        for _ in range(10):
            # Run the graph, which is basically two function calls.
            ans_u0, ans_v0, ans_w0, ans_u1, ans_v1, ans_w1 = sess.run(
                [u0, v0, w0, u1, v1, w1])
            self.assertAllEqual(ans_u0, self._mat(10.0))   # 2 * 3 + 4 = 10
            self.assertAllEqual(ans_v0, self._mat(100.0))  # 10^2 = 100
            self.assertAllEqual(ans_w0, self._mat(110.0))  # 10 + 100 = 110
            self.assertAllEqual(ans_u1, self._mat(11.0))   # 4 * 2 + 3 = 11
            self.assertAllEqual(ans_v1, self._mat(121.0))  # 11^2 = 121
            self.assertAllEqual(ans_w1, self._mat(132.0))  # 11 + 121 = 132