Пример #1
0
  def test_graph_replace_gradients(self):
    """Check graph_replace on a gradient tensor copied into the "res" scope.

    Builds y = (w * w) * w for a variable w initialised to 0.0, takes the
    gradient of y with respect to w, and calls graph_replace on that gradient
    with w's op mapped to the gradient itself.  The test then inspects the
    `_original_op` recorded on the "mul1_grad/Mul_1" op before and after the
    copy, checks that the copied tensor has a different name, and checks that
    both gradient tensors evaluate to approximately 0.0.
    """
    ops.reset_default_graph()

    def lookup_op(op_name):
      # Resolve an operation by name in whatever graph is currently default.
      return ops.get_default_graph().get_operation_by_name(op_name)

    w = variables.VariableV1(0.0, name="w")
    w_squared = math_ops.multiply(w, w, name="mul1")
    y = math_ops.multiply(w_squared, w, name="mul2")
    gradients = tf.gradients(y, w, name="grad")
    grad = gradients[0]

    # Substitution map keyed on the variable's op.  A tensor-keyed variant,
    # {w.value(): g}, had been left commented out next to this line.
    substitutions = {w.op: grad}
    mul1_grad_before = lookup_op("grad/mul1_grad/Mul_1")

    # The call itself is part of the test: it should not raise.
    copied_grad = ge.graph_replace(grad, substitutions, dst_scope="res")

    # Same gradient op, now looked up under the destination scope.
    mul1_grad_after = lookup_op("res/grad/mul1_grad/Mul_1")

    # Each gradient op should point back at the forward op it came from.
    self.assertEqual(mul1_grad_before._original_op.name, u"mul1")
    self.assertEqual(mul1_grad_after._original_op.name, u"res/mul1")
    self.assertNotEqual(copied_grad.name, grad.name)

    with session.Session() as sess:
      sess.run(variables.global_variables_initializer())
      grad_value, copied_value = sess.run([grad, copied_grad])
    self.assertNear(grad_value, 0.0, ERROR_TOLERANCE)
    self.assertNear(copied_value, 0.0, ERROR_TOLERANCE)
Пример #2
0
 def test_graph_replace_missing(self):
   """Check graph_replace when one requested output does not use the key.

   The outputs requested are the constant b and c = a + 2 * b, with a mapped
   to a new constant d.  The test asserts the names of the two returned
   tensors: "b:0" for the first and "add_1:0" for the second.
   """
   ops.reset_default_graph()
   a = constant_op.constant(1.0, name="a")
   b = constant_op.constant(2.0, name="b")
   twice_b = 2 * b
   c = a + twice_b
   d = constant_op.constant(2.0, name="d")

   outputs = ge.graph_replace([b, c], {a: d})

   expected_names = ["b:0", "add_1:0"]
   for position, expected in enumerate(expected_names):
     self.assertEqual(outputs[position].name, expected)
Пример #3
0
 def test_graph_replace_ordered_dict(self):
   """Check that graph_replace preserves an OrderedDict container.

   The target c = identity(a + b + eps) is wrapped in an OrderedDict under
   the key "c"; after replacing a with a_new the result must also be an
   OrderedDict.
   """
   ops.reset_default_graph()
   a = constant_op.constant(1.0, name="a")
   b = variables.Variable(1.0, name="b")
   eps = constant_op.constant(0.001, name="eps")
   summed = a + b + eps
   c = array_ops.identity(summed, name="c")
   a_new = constant_op.constant(2.0, name="a_new")

   targets = collections.OrderedDict({"c": c})
   replaced = ge.graph_replace(targets, {a: a_new})

   self.assertIsInstance(replaced, collections.OrderedDict)
Пример #4
0
 def test_graph_replace_named_tuple(self):
   """Check that graph_replace preserves a namedtuple container.

   The target c = identity(a + b + eps) is wrapped in a single-field
   namedtuple; after replacing a with a_new the result must be an instance
   of that same namedtuple type.
   """
   ops.reset_default_graph()
   a = constant_op.constant(1.0, name="a")
   b = variables.Variable(1.0, name="b")
   eps = constant_op.constant(0.001, name="eps")
   summed = a + b + eps
   c = array_ops.identity(summed, name="c")
   a_new = constant_op.constant(2.0, name="a_new")

   tensor_holder = collections.namedtuple("OneTensor", ["t"])
   wrapped = tensor_holder(t=c)
   replaced = ge.graph_replace(wrapped, {a: a_new})

   self.assertIsInstance(replaced, tensor_holder)
Пример #5
0
 def test_graph_replace(self):
   """Check that graph_replace swaps an input and changes the result.

   With a = 1.0, b = 1.0 and eps = 0.001, the original c = a + b + eps
   evaluates to 2.001.  Replacing a with a_new = 2.0 should give a copy of c
   that evaluates to 3.001, while the original c is still 2.001.
   """
   ops.reset_default_graph()
   a = constant_op.constant(1.0, name="a")
   b = variables.Variable(1.0, name="b")
   eps = constant_op.constant(0.001, name="eps")
   summed = a + b + eps
   c = array_ops.identity(summed, name="c")
   a_new = constant_op.constant(2.0, name="a_new")

   c_replaced = ge.graph_replace(c, {a: a_new})

   with session.Session() as sess:
     # b is a variable, so it has to be initialised before evaluation.
     sess.run(variables.global_variables_initializer())
     original_value, replaced_value = sess.run([c, c_replaced])
   self.assertNear(original_value, 2.001, ERROR_TOLERANCE)
   self.assertNear(replaced_value, 3.001, ERROR_TOLERANCE)