def testGroup_MultiDevice(self):
  """tf.group over ops on two devices, with the group itself placed on a third.

  The expected graph has one intermediate NoOp per input device ("root/NoOp"
  on /task:0, "root/NoOp_1" on /task:1). It also has a final "root" NoOp on
  /task:2 that depends on both.

  NOTE(review): a later method in this file has the same name and an identical
  body. That later definition replaces this one on the class, so this copy
  never runs. Consider deleting one of the two.
  """
  with ops.Graph().as_default() as graph:
    with graph.device("/task:0"):
      const_a = tf.constant(0, name="a")
      const_b = tf.constant(0, name="b")
    with graph.device("/task:1"):
      const_c = tf.constant(0, name="c")
      const_d = tf.constant(0, name="d")
    with graph.device("/task:2"):
      tf.group(const_a.op, const_b.op, const_c.op, const_d.op, name="root")
  graph_def = graph.as_graph_def()
  # Adjacent literals concatenate to exactly the original expected text.
  expected = (
      ' node { name: "a" op: "Const" device: "/task:0"}'
      ' node { name: "b" op: "Const" device: "/task:0"}'
      ' node { name: "c" op: "Const" device: "/task:1"}'
      ' node { name: "d" op: "Const" device: "/task:1"}'
      ' node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }'
      ' node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d" device: "/task:1" }'
      ' node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1" device: "/task:2" }'
      ' '
  )
  self.assertProtoEquals(expected, self._StripGraph(graph_def))
def testGroup_OneDevice(self):
  """tf.group where every input and the group itself share /task:0.

  The expected graph has only a single "root" NoOp with control inputs on
  "a" and "b". No intermediate per-device NoOps appear.
  """
  with ops.Graph().as_default() as graph:
    with graph.device("/task:0"):
      consts = [tf.constant(0, name=label) for label in ("a", "b")]
      tf.group(*[t.op for t in consts], name="root")
  graph_def = graph.as_graph_def()
  # Adjacent literals concatenate to exactly the original expected text.
  expected = (
      ' node { name: "a" op: "Const" device: "/task:0" }'
      ' node { name: "b" op: "Const" device: "/task:0" }'
      ' node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }'
      ' '
  )
  self.assertProtoEquals(expected, self._StripGraph(graph_def))
def testGroup_NoDevices(self):
  """tf.group with no device scopes anywhere.

  The expected graph has three Const nodes and one "root" NoOp with control
  inputs on all three. No node carries a device field.
  """
  with ops.Graph().as_default() as graph:
    consts = [tf.constant(0, name=label) for label in ("a", "b", "c")]
    tf.group(*[t.op for t in consts], name="root")
  graph_def = graph.as_graph_def()
  # Adjacent literals concatenate to exactly the original expected text.
  expected = (
      ' node { name: "a" op: "Const"}'
      ' node { name: "b" op: "Const"}'
      ' node { name: "c" op: "Const"}'
      ' node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }'
      ' '
  )
  self.assertProtoEquals(expected, self._StripGraph(graph_def))
def testGroup_MultiDevice(self):
  """tf.group over ops spread across /task:0 and /task:1, grouped on /task:2.

  The expected graph has one intermediate NoOp per input device ("root/NoOp"
  on /task:0, "root/NoOp_1" on /task:1). It also has a final "root" NoOp on
  /task:2 that depends on both.
  """
  with ops.Graph().as_default() as graph:
    # Constants are created in the order a, b, c, d, each pair under its device.
    placement = (("/task:0", ("a", "b")), ("/task:1", ("c", "d")))
    member_ops = []
    for device_name, labels in placement:
      with graph.device(device_name):
        member_ops.extend(tf.constant(0, name=label).op for label in labels)
    with graph.device("/task:2"):
      tf.group(*member_ops, name="root")
  graph_def = graph.as_graph_def()
  # Adjacent literals concatenate to exactly the original expected text.
  expected = (
      ' node { name: "a" op: "Const" device: "/task:0"}'
      ' node { name: "b" op: "Const" device: "/task:0"}'
      ' node { name: "c" op: "Const" device: "/task:1"}'
      ' node { name: "d" op: "Const" device: "/task:1"}'
      ' node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }'
      ' node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d" device: "/task:1" }'
      ' node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1" device: "/task:2" }'
      ' '
  )
  self.assertProtoEquals(expected, self._StripGraph(graph_def))