示例#1
0
    def test_graph_replace_gradients(self):
        """Check graph_replace on a gradient subgraph keeps op provenance.

        Builds y = (w * w) * w and takes its gradient with respect to w. It
        then copies the gradient subgraph under the "res" scope, substituting
        the gradient tensor for w's value. The test verifies:

        * the private `_original_op` links of the original and copied "mul1"
          gradient ops;
        * the copy gets a new name;
        * both the original gradient and its copy evaluate to 0.0 once w is
          initialized to 0.0.
        """
        ops.reset_default_graph()
        w = variables.VariableV1(0.0, name="w")
        w_squared = math_ops.multiply(w, w, name="mul1")
        y = math_ops.multiply(w_squared, w, name="mul2")
        grad = gradients_impl.gradients(y, w, name="grad")[0]

        # Tensors to substitute during the copy: w's value -> the gradient.
        replacement_ts = {w.value(): grad}
        graph = ops.get_default_graph()
        original_mul1_grad = graph.get_operation_by_name(
            "grad/mul1_grad/Mul_1")

        # The call itself is part of the check: it must not raise.
        res = ge.graph_replace(grad, replacement_ts, dst_scope="res")
        result_mul1_grad = graph.get_operation_by_name(
            "res/grad/mul1_grad/Mul_1")

        # The original gradient op links back to forward op "mul1"; its copy
        # must link to the copied forward op under the "res" scope.
        self.assertEqual(original_mul1_grad._original_op.name, "mul1")
        self.assertEqual(result_mul1_grad._original_op.name, "res/mul1")
        self.assertNotEqual(res.name, grad.name)

        with session.Session() as sess:
            sess.run(variables.global_variables_initializer())
            grad_val, res_val = sess.run([grad, res])
        for value in (grad_val, res_val):
            self.assertNear(value, 0.0, ERROR_TOLERANCE)
示例#2
0
 def test_graph_replace_missing(self):
     """Check outputs that do not consume a replaced tensor keep their name.

     `b` does not depend on `a`, so it comes back as "b:0". The sum depends
     on `a`, so a new copy is built in the same graph; that is why its name
     is uniquified to "add_1:0".
     """
     ops.reset_default_graph()
     const_a = constant_op.constant(1.0, name="a")
     const_b = constant_op.constant(2.0, name="b")
     total = const_a + 2 * const_b
     const_d = constant_op.constant(2.0, name="d")
     outputs = ge.graph_replace([const_b, total], {const_a: const_d})
     expected_names = ["b:0", "add_1:0"]
     for index, expected in enumerate(expected_names):
         self.assertEqual(outputs[index].name, expected)
示例#3
0
 def test_graph_replace_ordered_dict(self):
     """Check graph_replace keeps an OrderedDict container around its outputs.

     Also checks that the single entry keeps its key and that its value is
     a newly copied tensor rather than the original `c`.
     """
     ops.reset_default_graph()
     a = constant_op.constant(1.0, name="a")
     b = variables.Variable(1.0, name="b")
     eps = constant_op.constant(0.001, name="eps")
     c = array_ops.identity(a + b + eps, name="c")
     a_new = constant_op.constant(2.0, name="a_new")
     c_new = ge.graph_replace(collections.OrderedDict({"c": c}), {a: a_new})
     self.assertIsInstance(c_new, collections.OrderedDict)
     # A type check alone would pass even if nothing had been replaced, so
     # also confirm the key survived and the value is a fresh copy of `c`.
     self.assertEqual(list(c_new.keys()), ["c"])
     self.assertNotEqual(c_new["c"].name, c.name)
示例#4
0
 def test_graph_replace_named_tuple(self):
     """Check graph_replace keeps a namedtuple container around its outputs.

     Also checks that the tuple's field holds a newly copied tensor rather
     than the original `c`.
     """
     ops.reset_default_graph()
     a = constant_op.constant(1.0, name="a")
     b = variables.Variable(1.0, name="b")
     eps = constant_op.constant(0.001, name="eps")
     c = array_ops.identity(a + b + eps, name="c")
     a_new = constant_op.constant(2.0, name="a_new")
     one_tensor = collections.namedtuple("OneTensor", ["t"])
     c_new = ge.graph_replace(one_tensor(c), {a: a_new})
     self.assertIsInstance(c_new, one_tensor)
     # A type check alone would pass even if nothing had been replaced, so
     # also confirm the field now holds a copy rather than `c` itself.
     self.assertNotEqual(c_new.t.name, c.name)
示例#5
0
 def test_graph_replace(self):
     """Check that replacing constant `a` with `a_new` changes the copy's value.

     The original is 1.0 + 1.0 + 0.001 = 2.001. The copy, built with `a`
     replaced by the constant 2.0, is 2.0 + 1.0 + 0.001 = 3.001. Both are
     fetched in a single session run.
     """
     ops.reset_default_graph()
     const_a = constant_op.constant(1.0, name="a")
     var_b = variables.Variable(1.0, name="b")
     epsilon = constant_op.constant(0.001, name="eps")
     original = array_ops.identity(const_a + var_b + epsilon, name="c")
     replacement = constant_op.constant(2.0, name="a_new")
     replaced = ge.graph_replace(original, {const_a: replacement})
     with session.Session() as sess:
         sess.run(variables.global_variables_initializer())
         original_val, replaced_val = sess.run([original, replaced])
     for actual, expected in ((original_val, 2.001), (replaced_val, 3.001)):
         self.assertNear(actual, expected, ERROR_TOLERANCE)