def test_reroute(self):
    """Reroute c0/c1 onto (a0, b0), then back onto (a1, b1)."""
    consumers = (("c0", self.c0), ("c1", self.c1))

    def assert_consumers_read(*expected_inputs):
        # Build a fresh matcher per consumer and check its input ops.
        for op_name, tensor in consumers:
            matcher = gde.OpMatcher(op_name).input_ops(*expected_inputs)
            self.assertTrue(matcher(tensor.op))

    # After this call both c0 and c1 are expected to read from a0/b0.
    gde.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
    assert_consumers_read("a0", "b0")

    # Reverse direction: both consumers are expected to read from a1/b1.
    gde.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
    assert_consumers_read("a1", "b1")
def test_multiswap(self):
    """Swap a new constant a3 in for both a0 and a1 with one swap_ios call."""
    # Ported from the TF graph_editor test, which created a3 with
    # constant_op.constant(3.0, shape=[2], name="a3") inside
    # `with self.graph.as_default():`. Here a Const NodeDef is added directly.
    a3_value = np.full([2], 3.0, dtype=np.float32)
    a3_node = gde.make_const(self.graph, "a3", a3_value)

    # Repeat a3's output 0 so the view exposes two outputs, one for each
    # of the ops a0 and a1 it is swapped against.
    replacement = gde.sgv(a3_node).remap_outputs([0, 0])
    originals = gde.sgv(self.a0.op, self.a1.op)
    gde.swap_ios(replacement, originals)

    c0_matcher = gde.OpMatcher("c0").input_ops("a3", "b0")
    self.assertTrue(c0_matcher(self.c0.op))
    c1_matcher = gde.OpMatcher("c1").input_ops("a3", "b1")
    self.assertTrue(c1_matcher(self.c1.op))
def test_detach(self):
    """Test for gde.detach."""
    subgraph = gde.sgv(self.c.op, self.a.op)
    gde.detach(subgraph, control_ios=gde.ControlOutputs(self.graph))

    # a is inside the detached view, so c keeps it as an input. c's other
    # input is expected to be "geph__b_0" (presumably a stand-in that detach
    # creates for b, which lies outside the view; confirm against gde.detach).
    expected = gde.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")
    self.assertTrue(expected(self.c.op))
def test_simple_match(self):
    """Match op f by name, input names, input op types and output ops."""

    def f_matcher():
        # A brand-new matcher for every check, mirroring the original test,
        # which never reuses a configured OpMatcher between assertions.
        return gde.OpMatcher("^.*/f$")

    def h_matcher():
        # Matcher for op h, which must take c as a control input.
        return gde.OpMatcher("^.*/h$").control_input_ops("^.*/c$")

    self.assertTrue(f_matcher()(self.f_op))
    self.assertTrue(f_matcher().input_ops("^.*/c$", "^.*/d$")(self.f_op))
    self.assertTrue(f_matcher().input_ops(True, "^.*/d$")(self.f_op))

    # f's first input may come from an "Add" or an "AddV2" op (presumably
    # depending on the TF version that built the graph); accept either.
    # any() short-circuits exactly like the `or` it replaces.
    self.assertTrue(
        any(
            f_matcher().input_ops(
                gde.op_type(add_type), gde.op_type("Const"))(self.f_op)
            for add_type in ("Add", "AddV2")))

    self.assertTrue(
        f_matcher().input_ops("^.*/c$", "^.*/d$").output_ops(h_matcher())(
            self.f_op))
    # Same as above, and h must additionally match an empty output list.
    self.assertTrue(
        f_matcher().input_ops("^.*/c$", "^.*/d$").output_ops(
            h_matcher().output_ops([]))(self.f_op))
def test_reroute_can_modify(self):
    """swap_outputs must handle a tensor that sgv0 both produces and consumes.

    Tensor "a" is produced by op a and consumed by op c, both of which are in
    sgv0. It is also consumed by d, which is outside sgv0.
    """
    tf_graph = tf.Graph()
    with tf_graph.as_default():
        a = tf.constant(1.0, shape=[2], name="a")
        b = tf.constant(2.0, shape=[2], name="b")
        c = tf.add(a, b, name="c")
        tf.add(a, c, name="d")
        e = tf.constant(1.0, shape=[2], name="e")
        f = tf.constant(2.0, shape=[2], name="f")
        tf.add(e, f, name="g")

    graph = gde.Graph(tf_graph)
    sgv0 = gde.sgv(graph["a"], graph["b"], graph["c"])
    sgv1 = gde.sgv(graph["e"], graph["f"])
    gde.swap_outputs(sgv0, sgv1)

    # g now reads a and c, and c is still fed by a and b.
    g_expected = gde.OpMatcher("g").input_ops(
        "a", gde.OpMatcher("c").input_ops("a", "b"))
    self.assertTrue(g_expected(graph["g"]))
    # d now reads e and f.
    d_expected = gde.OpMatcher("d").input_ops("e", "f")
    self.assertTrue(d_expected(graph["d"]))
def test_connect(self):
    """Test for gde.connect."""
    # Ported from the TF graph_editor test, which built x, y and z = x + y
    # with constant_op/math_ops inside `with self.graph.as_default():`.
    # Here the nodes are added to the graph explicitly.
    x_node = gde.make_const(
        self.graph, "x", np.array([1., 1.], dtype=np.float32))
    y_node = gde.make_const(
        self.graph, "y", np.array([2., 2.], dtype=np.float32))

    z_node = self.graph.add_node("z", "Add")
    z_node.add_attr("T", tf.float32)
    z_node.set_inputs([x_node.outputs[0], y_node.outputs[0]])
    z_node.infer_outputs()

    # Connect the x/y/z subgraph's outputs to e's first input only.
    source = gde.sgv(x_node, y_node, z_node)
    destination = gde.sgv(self.e.op).remap_inputs([0])
    gde.connect(source, destination)

    expected = gde.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")
    self.assertTrue(expected(self.e.op))
def test_transform(self):
    """Copy the graph with a handler that appends noise after every Add op."""

    def noisy_copy_handler(info, op, new_inputs):
        # Check the source op's name before copying it.
        is_add = op.name.startswith("Add")
        copied_op, copied_outputs = gde.transform.copy_op_handler(
            info, op, new_inputs)
        if is_add:
            # Ported from the TF graph_editor test, which did
            #   math_ops.add(constant_op.constant(1.0, shape=[10],
            #                                     name="Noise"),
            #                op_.outputs[0], name="AddNoise")
            # inside `with info.graph_.as_default():`.
            noise_node = gde.make_const(info.graph_, "Noise",
                                        np.full([10], 1., dtype=np.float32),
                                        uniquify_name=True)
            noisy_add = info.graph_.add_node("AddNoise", "Add",
                                             uniquify_name=True)
            noisy_add.add_attr("T", tf.float32)
            noisy_add.set_inputs(
                [noise_node.outputs[0], copied_op.outputs[0]])
            noisy_add.infer_outputs()
            # Report the noisy tensor as this op's output, so the copies of
            # downstream ops consume it (see the expected chain below).
            return copied_op, [noisy_add.outputs[0]]
        return copied_op, copied_outputs

    transformer = gde.Transformer()
    transformer.transform_op_handler = noisy_copy_handler
    target_graph = gde.Graph()
    transformer(self.graph, target_graph, "", "")

    # Expected chain: each Add_k consumes the previous AddNoise output.
    first_noise = gde.OpMatcher("AddNoise").input_ops(
        "Noise", gde.OpMatcher("Add").input_ops("Const", "Input"))
    second_noise = gde.OpMatcher("AddNoise_1").input_ops(
        "Noise_1", gde.OpMatcher("Add_1").input_ops("Const_1", first_noise))
    third_noise = gde.OpMatcher("AddNoise_2").input_ops(
        "Noise_2", gde.OpMatcher("Add_2").input_ops("Const_2", second_noise))

    last_op = gde.select_ops("^AddNoise_2$", graph=target_graph)[0]
    self.assertTrue(third_noise(last_op))
def test_bypass(self):
    """Test for gde.bypass."""
    # Bypass f, restricting the view to its first input.
    gde.bypass(gde.sgv(self.f.op).remap_inputs([0]))

    # Afterwards h is expected to read foo/c and foo/bar/g. NOTE(review):
    # presumably foo/c is f's first input and h was fed by f before the
    # bypass; confirm against the fixture.
    expected = gde.OpMatcher("^foo/bar/h$").input_ops("^foo/c$", "foo/bar/g$")
    self.assertTrue(expected(self.h.op))