def test_transform_nodedef_fn(self):
    """Check that a ``nodedef_fn`` passed to copy_op_handler edits copied nodes.

    The custom function drops the ``_foo`` attribute (when present) and sets
    ``_bar`` to ``b"bar"`` on every NodeDef copied into the destination graph.
    The source graph must keep its ``_foo`` value.
    """

    def strip_foo_set_bar(node_def):
        # Remove "_foo" only if it exists, then always stamp "_bar".
        if "_foo" in node_def.attr:
            del node_def.attr["_foo"]
        node_def.attr["_bar"].s = b"bar"
        return node_def

    transformer = gde.Transformer()
    transformer.transform_op_handler = functools.partial(
        gde.transform.copy_op_handler, nodedef_fn=strip_foo_set_bar)

    dest_graph = gde.Graph()
    transformer(self.graph, dest_graph, "", "")

    # Look both nodes up before asserting, so a lookup failure is never
    # mistaken for the expected ValueError below.
    src_const = self.graph["Const"]
    dst_const = dest_graph["Const"]
    self.assertEqual(src_const.get_attr("_foo"), "foo")
    with self.assertRaises(ValueError):
        dst_const.get_attr("_foo")

    for node in dest_graph.nodes:
        self.assertEqual(node.get_attr("_bar"), "bar")
def test_transform(self):
    """Check that a custom op handler can splice extra nodes after each copy.

    Every op whose name starts with "Add" is copied as usual and is then
    followed by a new ``Noise`` constant and an ``AddNoise`` Add node. The
    handler returns the AddNoise output in place of the copied op's outputs,
    so downstream copied ops consume the "noisy" tensor. Nested OpMatchers
    then verify the resulting Add -> AddNoise chain in the destination graph.
    """
    transformer = gde.Transformer()

    def my_transform_op_handler(info, op, new_inputs):
        add_noise = op.name.startswith("Add")
        op_, op_outputs_ = gde.transform.copy_op_handler(
            info, op, new_inputs)
        if not add_noise:
            return op_, op_outputs_
        # Build Noise (a float32 constant of shape [10] filled with 1.0) and
        # AddNoise = Noise + op_.outputs[0] directly in the destination graph.
        # uniquify_name=True yields Noise_1/AddNoise_1/... on repeat calls.
        noise_op = gde.make_const(info.graph_, "Noise",
                                  np.full([10], 1., dtype=np.float32),
                                  uniquify_name=True)
        add_noise_op = info.graph_.add_node("AddNoise", "Add",
                                            uniquify_name=True)
        add_noise_op.add_attr("T", tf.float32)
        add_noise_op.set_inputs([noise_op.outputs[0], op_.outputs[0]])
        add_noise_op.infer_outputs()
        t_ = add_noise_op.outputs[0]
        # Return the "noisy" tensor in place of the copied op's outputs.
        return op_, [t_]

    transformer.transform_op_handler = my_transform_op_handler

    graph = gde.Graph()
    transformer(self.graph, graph, "", "")

    # Each level: AddNoise_k(Noise_k, Add_k(Const_k, <previous level>)).
    matcher0 = gde.OpMatcher("AddNoise").input_ops(
        "Noise", gde.OpMatcher("Add").input_ops("Const", "Input"))
    matcher1 = gde.OpMatcher("AddNoise_1").input_ops(
        "Noise_1", gde.OpMatcher("Add_1").input_ops("Const_1", matcher0))
    matcher2 = gde.OpMatcher("AddNoise_2").input_ops(
        "Noise_2", gde.OpMatcher("Add_2").input_ops("Const_2", matcher1))
    top = gde.select_ops("^AddNoise_2$", graph=graph)[0]
    self.assertTrue(matcher2(top))
def test_copy_assert(self):
    """Check that copying a subgraph view containing a tf.Assert succeeds.

    Builds a TF graph in which an Assert on ``EQ`` gates an Add through a
    control dependency. The graph is wrapped in a gde.Graph, a view of the
    Assert and its inputs is copied into the view's own graph, and the test
    checks that the Assert op has a transformed counterpart.
    """
    tf_g = tf.Graph()
    with tf_g.as_default():
        lhs = tf.constant(1, name="a")
        rhs = tf.constant(1, name="b")
        assert_tf_op = tf.Assert(tf.equal(lhs, rhs, name="EQ"), [lhs, rhs])
        with tf.control_dependencies([assert_tf_op]):
            tf.add(lhs, rhs)

    g = gde.Graph(tf_g)
    assert_op = g[assert_tf_op.name]
    view = gde.make_view([assert_op, g["EQ"], g["a"], g["b"]])

    _, info = gde.Transformer()(view, view.graph, "", "")
    self.assertIsNotNone(info.transformed(assert_op))