def test_graph_replace_gradients(self):
    """Test graph_replace on a gradient subgraph.

    Builds y = (w * w) * w and its gradient g = dy/dw. It then copies g into
    the "res" scope with the tensor holding w's value replaced by g itself.
    The test checks that:
      * a copied gradient op's ``_original_op`` points at the copied forward
        op ("res/mul1"), mirroring the original's link to "mul1";
      * the result is a new tensor, distinct from g;
      * both g and the replaced result evaluate to ~0.0 with w initialized
        to 0.0.
    """
    ops.reset_default_graph()
    w = variables.VariableV1(0.0, name="w")
    y = math_ops.multiply(math_ops.multiply(w, w, name="mul1"), w, name="mul2")
    g = tf.gradients(y, w, name="grad")[0]

    # graph_replace maps original *tensors* to replacement tensors, so the key
    # must be the tensor carrying w's value (w.value()), not w's Operation.
    # An Operation key never matches any tensor, so nothing would be replaced.
    replacement_ts = {w.value(): g}
    original_mul1_grad = (
        ops.get_default_graph().get_operation_by_name("grad/mul1_grad/Mul_1"))

    # Should not raise exception.
    res = ge.graph_replace(g, replacement_ts, dst_scope="res")

    # Look up the copy of the same gradient op under the destination scope.
    result_mul1_grad = (
        ops.get_default_graph().get_operation_by_name(
            "res/grad/mul1_grad/Mul_1"))

    # Make sure _original_ops are as expected.
    self.assertEqual(original_mul1_grad._original_op.name, u"mul1")
    self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1")
    self.assertNotEqual(res.name, g.name)
    with session.Session() as sess:
        sess.run(variables.global_variables_initializer())
        g_val, res_val = sess.run([g, res])
    self.assertNear(g_val, 0.0, ERROR_TOLERANCE)
    self.assertNear(res_val, 0.0, ERROR_TOLERANCE)
def test_graph_replace_missing(self):
    """Test that targets not depending on a replaced tensor are left alone.

    ``b`` does not depend on ``a``, so it comes back under its original name.
    ``c = a + 2 * b`` does depend on ``a``, so it is re-created as a new op
    ("add_1") with ``a`` swapped for ``d``.
    """
    ops.reset_default_graph()
    const_a = constant_op.constant(1.0, name="a")
    const_b = constant_op.constant(2.0, name="b")
    total = const_a + 2 * const_b
    const_d = constant_op.constant(2.0, name="d")

    kept_b, rebuilt_c = ge.graph_replace([const_b, total], {const_a: const_d})

    self.assertEqual(kept_b.name, "b:0")
    self.assertEqual(rebuilt_c.name, "add_1:0")
def test_graph_replace_ordered_dict(self):
    """Test that graph_replace hands back an OrderedDict for OrderedDict targets."""
    ops.reset_default_graph()
    src = constant_op.constant(1.0, name="a")
    var = variables.Variable(1.0, name="b")
    epsilon = constant_op.constant(0.001, name="eps")
    out = array_ops.identity(src + var + epsilon, name="c")
    replacement = constant_op.constant(2.0, name="a_new")

    targets = collections.OrderedDict([("c", out)])
    replaced = ge.graph_replace(targets, {src: replacement})

    self.assertIsInstance(replaced, collections.OrderedDict)
def test_graph_replace_named_tuple(self):
    """Test that graph_replace returns the same namedtuple type it was given."""
    ops.reset_default_graph()
    src = constant_op.constant(1.0, name="a")
    var = variables.Variable(1.0, name="b")
    epsilon = constant_op.constant(0.001, name="eps")
    out = array_ops.identity(src + var + epsilon, name="c")
    replacement = constant_op.constant(2.0, name="a_new")

    # A one-field namedtuple wrapping the single target tensor.
    OneTensor = collections.namedtuple("OneTensor", ["t"])
    replaced = ge.graph_replace(OneTensor(t=out), {src: replacement})

    self.assertIsInstance(replaced, OneTensor)
def test_graph_replace(self):
    """Test that replacing ``a`` with a new constant changes the computed value.

    The original c = a + b + eps evaluates to about 2.001 (a = 1.0). The copy
    with a -> a_new (2.0) evaluates to about 3.001.
    """
    ops.reset_default_graph()
    src = constant_op.constant(1.0, name="a")
    var = variables.Variable(1.0, name="b")
    epsilon = constant_op.constant(0.001, name="eps")
    out = array_ops.identity(src + var + epsilon, name="c")
    replacement = constant_op.constant(2.0, name="a_new")
    replaced_out = ge.graph_replace(out, {src: replacement})

    with session.Session() as sess:
        sess.run(variables.global_variables_initializer())
        values = sess.run([out, replaced_out])
    before, after = values

    self.assertNear(before, 2.001, ERROR_TOLERANCE)
    self.assertNear(after, 3.001, ERROR_TOLERANCE)