def test_transform_nodedef_fn(self):
    """Check that a custom ``nodedef_fn`` is applied to NodeDefs while copying.

    The hook handed to ``copy_op_handler`` removes the ``_foo`` attribute when a
    node has one and sets ``_bar`` to ``b"bar"``. The test expects:
      * the source graph's ``Const`` node to still report ``_foo == "foo"``;
      * the copied ``Const`` node to raise ``ValueError`` on ``_foo``;
      * every node of the destination graph to report ``_bar == "bar"``.
    """

    def strip_foo_add_bar(node_def):
        # Only delete "_foo" if present; "_bar" is set unconditionally.
        if "_foo" in node_def.attr:
            del node_def.attr["_foo"]
        node_def.attr["_bar"].s = b"bar"
        return node_def

    transformer = gde.Transformer()
    transformer.transform_op_handler = functools.partial(
        gde.transform.copy_op_handler, nodedef_fn=strip_foo_add_bar)

    dest_graph = gde.Graph()
    transformer(self.graph, dest_graph, "", "")

    # Look both nodes up before asserting so lookup errors surface as errors.
    const_src = self.graph["Const"]
    const_dst = dest_graph["Const"]

    self.assertEqual(const_src.get_attr("_foo"), "foo")
    with self.assertRaises(ValueError):
        const_dst.get_attr("_foo")

    for node in dest_graph.nodes:
        self.assertEqual(node.get_attr("_bar"), "bar")
Esempio n. 2
0
    def test_transform(self):
        """Copy a graph while appending a noise op after every ``Add*`` op.

        The custom handler copies each op with ``copy_op_handler``. For ops
        whose name starts with ``"Add"`` it also creates a float32 constant
        ``Noise`` (shape [10], all ones) plus an ``AddNoise`` op summing it with
        the copied op's first output. That sum is returned as the op's output,
        so consumers of the original op are wired to the noisy value instead.
        The resulting Add/AddNoise chain is verified with ``OpMatcher``.
        """
        transformer = gde.Transformer()

        def my_transform_op_handler(info, op, new_inputs):
            add_noise = op.name.startswith("Add")
            op_, op_outputs_ = gde.transform.copy_op_handler(
                info, op, new_inputs)
            if not add_noise:
                return op_, op_outputs_

            # Build, directly in the destination graph, the equivalent of the
            # TF-API expression
            #   add(constant(1.0, shape=[10], name="Noise"), op_.outputs[0],
            #       name="AddNoise")
            # uniquify_name=True is relied on by the matchers below, which
            # expect Noise/Noise_1/... and AddNoise/AddNoise_1/... names.
            noise_op = gde.make_const(info.graph_,
                                      "Noise",
                                      np.full([10], 1., dtype=np.float32),
                                      uniquify_name=True)
            add_noise_op = info.graph_.add_node("AddNoise",
                                                "Add",
                                                uniquify_name=True)
            add_noise_op.add_attr("T", tf.float32)
            add_noise_op.set_inputs([noise_op.outputs[0], op_.outputs[0]])
            add_noise_op.infer_outputs()
            t_ = add_noise_op.outputs[0]

            # Report the "noisy" tensor as the op's output.
            return op_, [t_]

        transformer.transform_op_handler = my_transform_op_handler

        graph = gde.Graph()
        transformer(self.graph, graph, "", "")
        matcher0 = gde.OpMatcher("AddNoise").input_ops(
            "Noise",
            gde.OpMatcher("Add").input_ops("Const", "Input"))
        matcher1 = gde.OpMatcher("AddNoise_1").input_ops(
            "Noise_1",
            gde.OpMatcher("Add_1").input_ops("Const_1", matcher0))
        matcher2 = gde.OpMatcher("AddNoise_2").input_ops(
            "Noise_2",
            gde.OpMatcher("Add_2").input_ops("Const_2", matcher1))
        tops = gde.select_ops("^AddNoise_2$", graph=graph)
        # Fail as a test assertion (not an IndexError) if the expected
        # top-of-chain op is missing.
        self.assertEqual(len(tops), 1)
        self.assertTrue(matcher2(tops[0]))
  def test_copy_assert(self):
    """Copying a view that contains a ``tf.Assert`` op records its copy.

    Builds a TF graph where an Assert on ``a == b`` guards an Add through a
    control dependency. It then wraps the graph in ``gde.Graph`` and copies a
    view of the Assert, ``EQ``, ``a`` and ``b`` into the view's own graph.
    Finally it checks that the transform info maps the Assert op to a copy.
    """
    tf_graph = tf.Graph()
    with tf_graph.as_default():
      const_a = tf.constant(1, name="a")
      const_b = tf.constant(1, name="b")
      assert_tf_op = tf.Assert(
          tf.equal(const_a, const_b, name="EQ"), [const_a, const_b])
      with tf.control_dependencies([assert_tf_op]):
        tf.add(const_a, const_b)

    gde_graph = gde.Graph(tf_graph)
    assert_op = gde_graph[assert_tf_op.name]
    view = gde.make_view(
        [assert_op, gde_graph["EQ"], gde_graph["a"], gde_graph["b"]])
    _, info = gde.Transformer()(view, view.graph, "", "")
    self.assertIsNotNone(info.transformed(assert_op))