def test_reroute(self):
        """Rerouting between (a0, b0) and (a1, b1) rewires both c0 and c1.

        After ``reroute_ts([a0, b0], [a1, b1])`` both c0 and c1 are expected
        to read from a0/b0; after rerouting in the opposite direction both
        are expected to read from a1/b1.
        """
        gde.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
        for op_name, tensor in (("c0", self.c0), ("c1", self.c1)):
            self.assertTrue(
                gde.OpMatcher(op_name).input_ops("a0", "b0")(tensor.op))

        gde.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
        for op_name, tensor in (("c0", self.c0), ("c1", self.c1)):
            self.assertTrue(
                gde.OpMatcher(op_name).input_ops("a1", "b1")(tensor.op))
    def test_multiswap(self):
        """swap_ios can feed one remapped output into several consumers.

        A float32 constant "a3" (shape [2], value 3.0) is added to
        self.graph as a NodeDef via gde.make_const; the TensorFlow original
        built it with constant_op.constant inside self.graph.as_default().
        After swapping its view with the (a0, a1) view, c0 is expected to
        read ("a3", "b0") and c1 to read ("a3", "b1").
        """
        a3 = gde.make_const(self.graph, "a3",
                            np.full([2], 3.0, dtype=np.float32))

        # Output 0 is listed twice so the a3 view exposes two outputs to
        # pair with the two-op (a0, a1) view.
        a3_view = gde.sgv(a3).remap_outputs([0, 0])
        a_view = gde.sgv(self.a0.op, self.a1.op)
        gde.swap_ios(a3_view, a_view)

        c0_matcher = gde.OpMatcher("c0").input_ops("a3", "b0")
        self.assertTrue(c0_matcher(self.c0.op))
        c1_matcher = gde.OpMatcher("c1").input_ops("a3", "b1")
        self.assertTrue(c1_matcher(self.c1.op))
Beispiel #3
0
 def test_detach(self):
   """Test for gde.detach.

   Detaches the subgraph view of ops c and a, passing the control outputs
   of self.graph. Afterwards op c is expected to have inputs matching "a"
   and "geph__b_0" (presumably a generated placeholder replacing the input
   that came from b).
   """
   view = gde.sgv(self.c.op, self.a.op)
   gde.detach(view, control_ios=gde.ControlOutputs(self.graph))
   expected = gde.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")
   self.assertTrue(expected(self.c.op))
Beispiel #4
0
 def test_simple_match(self):
     """Exercise OpMatcher name, input, output and control-input filters.

     Every check builds its own matcher for op f (regex "^.*/f$") through a
     small factory instead of reusing a single instance across checks.
     """
     def f_matcher():
         return gde.OpMatcher("^.*/f$")

     # Name filter alone, then name plus input filters; ``True`` as a filter
     # presumably accepts any input op.
     self.assertTrue(f_matcher()(self.f_op))
     self.assertTrue(f_matcher().input_ops("^.*/c$", "^.*/d$")(self.f_op))
     self.assertTrue(f_matcher().input_ops(True, "^.*/d$")(self.f_op))

     # f's first input may be an "Add" or an "AddV2" op (presumably this
     # depends on the TensorFlow version that built the graph); the second
     # must be a "Const". Stops at the first op type that matches.
     self.assertTrue(any(
         f_matcher().input_ops(gde.op_type(add_type),
                               gde.op_type("Const"))(self.f_op)
         for add_type in ("Add", "AddV2")))

     # Output filters: f feeds an op h with c as a control input; the last
     # check additionally passes output_ops([]) to h's matcher.
     h_controlled_by_c = gde.OpMatcher("^.*/h$").control_input_ops("^.*/c$")
     self.assertTrue(
         f_matcher().input_ops("^.*/c$", "^.*/d$").output_ops(
             h_controlled_by_c)(self.f_op))
     h_with_empty_outputs = gde.OpMatcher("^.*/h$").control_input_ops(
         "^.*/c$").output_ops([])
     self.assertTrue(
         f_matcher().input_ops("^.*/c$", "^.*/d$").output_ops(
             h_with_empty_outputs)(self.f_op))
    def test_reroute_can_modify(self):
        """swap_outputs copes with a tensor that is both input and output.

        sgv0 holds ops a, b and c, where c = a + b, so "a" is produced by an
        op of sgv0 and also consumed by another op of sgv0. Outside the view,
        d = a + c. sgv1 holds e and f, with g = e + f. After swapping their
        outputs:
        - g is expected to read "a" and a "c" that still reads ("a", "b");
        - d is expected to read ("e", "f").
        """
        tf_graph = tf.Graph()
        with tf_graph.as_default():
            # Nodes are created in the order a, b, c, d, e, f, g.
            lhs = tf.constant(1.0, shape=[2], name="a")
            rhs = tf.constant(2.0, shape=[2], name="b")
            tf.add(lhs, tf.add(lhs, rhs, name="c"), name="d")
            tf.add(tf.constant(1.0, shape=[2], name="e"),
                   tf.constant(2.0, shape=[2], name="f"),
                   name="g")
        graph = gde.Graph(tf_graph)

        sgv0 = gde.sgv(graph["a"], graph["b"], graph["c"])
        sgv1 = gde.sgv(graph["e"], graph["f"])
        gde.swap_outputs(sgv0, sgv1)

        c_reads_a_b = gde.OpMatcher("c").input_ops("a", "b")
        g_matcher = gde.OpMatcher("g").input_ops("a", c_reads_a_b)
        self.assertTrue(g_matcher(graph["g"]))
        d_matcher = gde.OpMatcher("d").input_ops("e", "f")
        self.assertTrue(d_matcher(graph["d"]))
Beispiel #6
0
  def test_connect(self):
    """Test for gde.connect.

    Builds z = x + y from two float32 constants, added as NodeDefs rather
    than through self.graph.as_default() as in the TensorFlow original.
    The (x, y, z) view is then connected to a view of e remapped to its
    input 0, after which e is expected to have inputs matching "^z$" and
    "foo/d$".
    """
    ones = np.array([1., 1.], dtype=np.float32)
    twos = np.array([2., 2.], dtype=np.float32)
    x = gde.make_const(self.graph, "x", ones)
    y = gde.make_const(self.graph, "y", twos)

    z = self.graph.add_node("z", "Add")
    z.add_attr("T", tf.float32)
    z.set_inputs([x.outputs[0], y.outputs[0]])
    z.infer_outputs()

    gde.connect(gde.sgv(x, y, z), gde.sgv(self.e.op).remap_inputs([0]))
    e_matcher = gde.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")
    self.assertTrue(e_matcher(self.e.op))
    def test_transform(self):
        """A custom transform_op_handler can splice extra ops into the copy.

        Every op is copied with gde.transform.copy_op_handler. For source
        ops whose name starts with "Add", the handler also adds a "Noise"
        constant (ten float32 1.0 values) and an "AddNoise" = Noise + copy
        node, and reports AddNoise's output as the op's output. The copied
        graph is expected to hold three chained Add/AddNoise stages that
        bottom out at "Input".
        """
        def add_noise_to(graph_, copied_op):
            # NodeDef-based equivalent of the TensorFlow original:
            #   math_ops.add(constant_op.constant(1.0, shape=[10],
            #                                     name="Noise"),
            #                copied_op.outputs[0], name="AddNoise")
            noise = gde.make_const(graph_,
                                   "Noise",
                                   np.full([10], 1., dtype=np.float32),
                                   uniquify_name=True)
            summed = graph_.add_node("AddNoise", "Add", uniquify_name=True)
            summed.add_attr("T", tf.float32)
            summed.set_inputs([noise.outputs[0], copied_op.outputs[0]])
            summed.infer_outputs()
            return summed.outputs[0]

        def noisy_copy_handler(info, op, new_inputs):
            # Decide from the source op's name before copying it.
            wants_noise = op.name.startswith("Add")
            copied_op, copied_outputs = gde.transform.copy_op_handler(
                info, op, new_inputs)
            if wants_noise:
                return copied_op, [add_noise_to(info.graph_, copied_op)]
            return copied_op, copied_outputs

        transformer = gde.Transformer()
        transformer.transform_op_handler = noisy_copy_handler

        dst_graph = gde.Graph()
        transformer(self.graph, dst_graph, "", "")

        # Innermost stage first, then wrap it in the "_1" and "_2" stages.
        matcher = gde.OpMatcher("AddNoise").input_ops(
            "Noise",
            gde.OpMatcher("Add").input_ops("Const", "Input"))
        for suffix in ("_1", "_2"):
            matcher = gde.OpMatcher("AddNoise" + suffix).input_ops(
                "Noise" + suffix,
                gde.OpMatcher("Add" + suffix).input_ops("Const" + suffix,
                                                        matcher))
        top = gde.select_ops("^AddNoise_2$", graph=dst_graph)[0]
        self.assertTrue(matcher(top))
Beispiel #8
0
 def test_bypass(self):
   """Test for gde.bypass.

   Bypasses a view of op f remapped to its input 0. Afterwards op h is
   expected to have inputs matching "^foo/c$" and "foo/bar/g$".
   """
   f_view = gde.sgv(self.f.op).remap_inputs([0])
   gde.bypass(f_view)
   h_matcher = gde.OpMatcher("^foo/bar/h$").input_ops("^foo/c$",
                                                      "foo/bar/g$")
   self.assertTrue(h_matcher(self.h.op))