Ejemplo n.º 1
0
 def test_swap(self):
     """Swap tensors (a0, b0) with (a1, b1) and check both consumers.

     After ``ge.swap_ts`` the op ``c0`` is expected to have inputs matching
     ``a1``/``b1`` and ``c1`` inputs matching ``a0``/``b0``.
     """
     first_pair = [self.a0, self.b0]
     second_pair = [self.a1, self.b1]
     ge.swap_ts(first_pair, second_pair)

     c0_expected = match.OpMatcher("c0").input_ops("a1", "b1")
     self.assertTrue(c0_expected(self.c0.op))
     c1_expected = match.OpMatcher("c1").input_ops("a0", "b0")
     self.assertTrue(c1_expected(self.c1.op))
Ejemplo n.º 2
0
 def test_multiswap(self):
     """Swap one new constant ``a3`` against both ``a0`` and ``a1`` at once.

     The view over ``a3`` is remapped with ``remap_outputs([0, 0])``, which
     lists output index 0 twice. It is then swapped via ``ge.swap_ios``
     with the view over ``a0`` and ``a1``. Afterwards ``c0`` should match
     inputs ``a3``/``b0`` and ``c1`` inputs ``a3``/``b1``.
     """
     with self.graph.as_default():
         a3 = constant_op.constant(3.0, shape=[2], name="a3")

     a3_view = ge.sgv(a3.op).remap_outputs([0, 0])
     a_ops_view = ge.sgv(self.a0.op, self.a1.op)
     ge.swap_ios(a3_view, a_ops_view)

     expectations = (
         (self.c0, match.OpMatcher("c0").input_ops("a3", "b0")),
         (self.c1, match.OpMatcher("c1").input_ops("a3", "b1")),
     )
     for consumer, matcher in expectations:
         self.assertTrue(matcher(consumer.op))
Ejemplo n.º 3
0
    def test_reroute(self):
        """Reroute tensor pairs in both directions and check c0/c1 follow.

        The first ``ge.reroute_ts([a0, b0], [a1, b1])`` call leaves both
        ``c0`` and ``c1`` matching inputs ``a0``/``b0``. Rerouting back with
        ``ge.reroute_ts([a1, b1], [a0, b0])`` then leaves both matching
        ``a1``/``b1``.
        """
        def expect_both_consumers_read(first, second):
            # Each consumer op (c0 and c1) must match the given input patterns.
            for op_name, consumer in (("c0", self.c0), ("c1", self.c1)):
                matcher = match.OpMatcher(op_name).input_ops(first, second)
                self.assertTrue(matcher(consumer.op))

        ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
        expect_both_consumers_read("a0", "b0")

        ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
        expect_both_consumers_read("a1", "b1")
Ejemplo n.º 4
0
 def test_detach(self):
   """Detach the subgraph view over ``c`` and ``a`` and inspect ``foo/c``.

   ``ge.detach`` receives control-output information built from
   ``self.graph``. Afterwards ``foo/c`` is expected to have inputs matching
   ``a`` and ``geph__b_0``. The latter presumably names the op inserted in
   place of ``b``; confirm against ``ge.detach``.
   """
   detached_view = ge.sgv(self.c.op, self.a.op)
   ge.detach(detached_view, control_ios=ge.ControlOutputs(self.graph))

   # Check that the detached graph has the expected wiring.
   c_expected = match.OpMatcher("^foo/c$").input_ops("a", "geph__b_0")
   self.assertTrue(c_expected(self.c.op))
Ejemplo n.º 5
0
  def test_connect(self):
    """Connect a freshly built ``z = x + y`` subgraph into ``e``.

    The target view over ``self.e.op`` is remapped to input index 0. After
    ``ge.connect`` the op ``foo/bar/e`` is expected to match inputs ``^z$``
    and ``foo/d$``.
    """
    with self.graph.as_default():
      lhs = constant_op.constant([1., 1.], shape=[2], name="x")
      rhs = constant_op.constant([2., 2.], shape=[2], name="y")
      total = math_ops.add(lhs, rhs, name="z")

    source_view = ge.sgv(lhs.op, rhs.op, total.op)
    target_view = ge.sgv(self.e.op).remap_inputs([0])
    ge.connect(source_view, target_view)

    e_expected = match.OpMatcher("^foo/bar/e$").input_ops("^z$", "foo/d$")
    self.assertTrue(e_expected(self.e.op))
Ejemplo n.º 6
0
  def test_transform(self):
    """Copy ``self.graph`` with a custom op handler that injects noise.

    The handler copies every op with ``ge.transform.copy_op_handler``. For
    ops whose name starts with "Add", it also builds an "AddNoise" op that
    adds a constant "Noise" tensor to the copy's first output, and returns
    that as the op's output. The test then checks that the new graph holds a
    chain of three noisy adds, ending at "AddNoise_2".
    """
    def noisy_copy_handler(info, op, new_inputs):
      # Decide from the source op's name before copying it.
      wants_noise = op.name.startswith("Add")
      copied_op, copied_outputs = ge.transform.copy_op_handler(
          info, op, new_inputs)
      if wants_noise:
        # Build the noise op inside the destination graph.
        with info.graph_.as_default():
          noisy_output = math_ops.add(
              constant_op.constant(1.0, shape=[10], name="Noise"),
              copied_op.outputs[0],
              name="AddNoise")
        # Expose the noisy tensor in place of the copy's own outputs.
        return copied_op, [noisy_output]
      return copied_op, copied_outputs

    transformer = ge.Transformer()
    transformer.transform_op_handler = noisy_copy_handler

    target_graph = ops.Graph()
    transformer(self.graph, target_graph, "", "")

    # Each level nests the previous one as the second input of its Add op.
    # The innermost level reads from "Input".
    expected = "Input"
    for add_noise_name, noise_name, add_name, const_name in (
        ("AddNoise", "Noise", "Add", "Const"),
        ("AddNoise_1", "Noise_1", "Add_1", "Const_1"),
        ("AddNoise_2", "Noise_2", "Add_2", "Const_2"),
    ):
      expected = match.OpMatcher(add_noise_name).input_ops(
          noise_name,
          match.OpMatcher(add_name).input_ops(const_name, expected))

    top = ge.select_ops("^AddNoise_2$", graph=target_graph)[0]
    self.assertTrue(expected(top))
Ejemplo n.º 7
0
    def test_reroute_can_modify(self):
        """Test that swap_outputs copes with an ambiguous tensor.

        Tensor "a" is created by an op in ``sgv0`` and is also consumed by
        "c", another op in ``sgv0``. After ``ge.swap_outputs(sgv0, sgv1)``:
        "g" is expected to match inputs "a" and c(a, b), and "d" is expected
        to match inputs "e" and "f".
        """
        graph = ops.Graph()
        with graph.as_default():
            # First group: "a" both feeds ops in sgv0 and is produced by
            # one of them.
            a_t = constant_op.constant(1.0, shape=[2], name="a")
            b_t = constant_op.constant(2.0, shape=[2], name="b")
            c_t = math_ops.add(a_t, b_t, name="c")
            d_t = math_ops.add(a_t, c_t, name="d")

            # Second group, used as the swap partner.
            e_t = constant_op.constant(1.0, shape=[2], name="e")
            f_t = constant_op.constant(2.0, shape=[2], name="f")
            g_t = math_ops.add(e_t, f_t, name="g")

        first_view = ge.sgv(a_t.op, b_t.op, c_t.op)
        second_view = ge.sgv(e_t.op, f_t.op)
        ge.swap_outputs(first_view, second_view)

        g_expected = match.OpMatcher("g").input_ops(
            "a",
            match.OpMatcher("c").input_ops("a", "b"))
        self.assertTrue(g_expected(g_t.op))
        d_expected = match.OpMatcher("d").input_ops("e", "f")
        self.assertTrue(d_expected(d_t.op))
Ejemplo n.º 8
0
 def test_simple_match(self):
   """Check several OpMatcher configurations against ``self.f.op``.

   The cases cover, in order:
   - a bare name regex;
   - input-name regexes;
   - a literal ``True`` in place of the first input pattern;
   - ``match.op_type`` predicates for inputs;
   - an ``output_ops`` constraint whose nested matcher checks control inputs;
   - the same constraint with ``output_ops([])`` added to the nested matcher.
   """
   f_name = "^.*/f$"
   c_name = "^.*/c$"
   d_name = "^.*/d$"
   h_name = "^.*/h$"

   candidates = [
       match.OpMatcher(f_name),
       match.OpMatcher(f_name).input_ops(c_name, d_name),
       match.OpMatcher(f_name).input_ops(True, d_name),
       match.OpMatcher(f_name).input_ops(
           match.op_type("Add"), match.op_type("Const")),
       match.OpMatcher(f_name).input_ops(c_name, d_name).output_ops(
           match.OpMatcher(h_name).control_input_ops(c_name)),
       match.OpMatcher(f_name).input_ops(c_name, d_name).output_ops(
           match.OpMatcher(h_name).control_input_ops(c_name).output_ops(
               [])),
   ]
   for candidate in candidates:
     self.assertTrue(candidate(self.f.op))
Ejemplo n.º 9
0
 def test_bypass(self):
   """Bypass ``self.f.op``, remapped to input index 0, and inspect ``h``.

   After ``ge.bypass`` the op ``foo/bar/h`` is expected to match inputs
   ``^foo/c$`` and ``foo/bar/g$``.
   """
   f_view = ge.sgv(self.f.op).remap_inputs([0])
   ge.bypass(f_view)

   h_expected = match.OpMatcher("^foo/bar/h$").input_ops(
       "^foo/c$", "foo/bar/g$")
   self.assertTrue(h_expected(self.h.op))