def testDeviceBeforeCond(self):
  """Checks cond_v2 branches built under an enclosing `ops.device` scope.

  Inside each branch a fresh constant is asserted to carry an empty device
  string, while the evaluated `device_placement_op` output is asserted to
  contain the device named by the surrounding scope. The GPU half only runs
  when `test_util.is_gpu_available()` is true; otherwise the test is marked
  skipped after the CPU half has already been checked.
  """
  with ops.Graph().as_default() as g, self.session(graph=g):

    def fn():
      probe = constant_op.constant(3.0)
      self.assertEqual("", probe.op.device)
      return test_ops.device_placement_op()

    with ops.device("/device:CPU:0"):
      cpu_pred = constant_op.constant(True)
      cpu_placement = self.evaluate(cond_v2.cond_v2(cpu_pred, fn, fn))
      self.assertIn(compat.as_bytes("CPU:0"), cpu_placement)

    def fn2():
      probe = constant_op.constant(3.0)
      self.assertEqual("", probe.op.device)
      return test_ops.device_placement_op()

    # Guard clause: skipTest raises, so nothing below runs without a GPU.
    if not test_util.is_gpu_available():
      self.skipTest("Test requires a GPU to check GPU device placement.")

    with ops.device("/device:GPU:0"):
      gpu_pred = constant_op.constant(True)
      gpu_placement = self.evaluate(cond_v2.cond_v2(gpu_pred, fn2, fn2))
      self.assertIn(compat.as_bytes("GPU:0"), gpu_placement)
def testColocateWithBeforeCond(self):
  """Checks that enclosing `colocate_with` scopes are seen inside cond_v2.

  A constant created inside a branch is asserted to report the colocation
  groups of the scope(s) that wrap the cond_v2 call: first only `a`, then
  both `a` and `b` when the scopes are nested.
  """
  with ops.Graph().as_default() as g:
    with self.session(graph=g):

      a = constant_op.constant([2.0], name="a")
      b = constant_op.constant([2.0], name="b")

      def fn():
        c = constant_op.constant(3.0)
        self.assertEqual([b"loc:@a"], c.op.colocation_groups())
        return c

      with ops.colocate_with(a.op):
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(
            cond_v2.cond_v2(constant_op.constant(True), fn, fn).eval(), 3)

      def fn2():
        c = constant_op.constant(3.0)
        self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
        return c

      with ops.colocate_with(a.op):
        with ops.colocate_with(b.op):
          self.assertEqual(
              cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(),
              3)
def testExternalControlDependencies(self):
  """Evaluates a cond_v2 whose true branch depends on an outside assign op.

  The true branch is built under `control_dependencies` on an `assign_add`
  created outside the cond. The predicate defaults to False; after
  evaluating the cond the variable is asserted to equal 2.0.
  """
  with ops.Graph().as_default(), self.test_session():
    v = variables.Variable(1.0)
    v.initializer.run()
    op = v.assign_add(1.0)

    def true_branch():
      with ops.control_dependencies([op]):
        return 1.0

    def false_branch():
      return 2.0

    pred = array_ops.placeholder_with_default(False, None)
    result = cond_v2.cond_v2(pred, true_branch, false_branch)
    result.eval()
    self.assertAllEqual(self.evaluate(v), 2.0)
def func_with_cond():
  """Builds a cond_v2 plus its gradient and checks the branch functions.

  For both the forward op and the gradient op, every node in the
  `then_branch` and `else_branch` function definitions is asserted not to
  have an op type listed in `_OPTIONAL_OPS`.

  Returns:
    The gradient of the cond output with respect to `x`.
  """
  pred = constant_op.constant(True, name="pred")
  x = constant_op.constant(1.0, name="x")

  def true_fn():
    intermediate = x + 1
    return intermediate * x

  def false_fn():
    return x + 1

  output = cond_v2.cond_v2(pred, true_fn, false_fn)
  grad = gradients_impl.gradients(output, x)[0]

  forward_if_op = output.op.inputs[0].op
  gradient_if_op = grad.op.inputs[0].op

  def verify_no_optional_ops(op, branch_name):
    # The branch attr holds the name of a function registered in the graph.
    fn_name = op.get_attr(branch_name).name
    branch_function = ops.get_default_graph()._get_function(fn_name)
    for node_def in branch_function.definition.node_def:
      self.assertNotIn(node_def.op, _OPTIONAL_OPS)

  # Same order as before: forward then/else, then gradient then/else.
  for if_op in (forward_if_op, gradient_if_op):
    for branch_attr in ("then_branch", "else_branch"):
      verify_no_optional_ops(if_op, branch_attr)

  return grad
def testSecondDerivative(self):
  """Checks first and second derivatives of a cond_v2 for both branches."""
  with self.test_session() as sess:
    pred = array_ops.placeholder(dtypes.bool, name="pred")
    x = constant_op.constant(3.0, name="x")

    def true_fn():
      return math_ops.pow(x, 3)

    def false_fn():
      return x

    cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
    first_order = gradients_impl.gradients(cond, [x])
    second_order = gradients_impl.gradients(first_order, [x])

    # (fetch, value fed to pred, expected result), listed in run order.
    cases = (
        (first_order, True, [27.0]),    # d[x^3]/dx = 3x^2
        (first_order, False, [1.0]),    # d[x]/dx = 1
        (second_order, True, [18.0]),   # d2[x^3]/dx2 = 6x
        (second_order, False, [0.0]),   # d2[x]/dx2 = 0
    )
    for fetch, pred_value, wanted in cases:
      got = sess.run(fetch, {pred: pred_value})
      self.assertEqual(got, wanted)
def _testCond(self, true_fn, false_fn, train_vals, feed_dict=None):
  """Compares cond_v2 against control_flow_ops.cond for both predicates.

  Args:
    true_fn: Callable used as the true branch of both conds.
    false_fn: Callable used as the false branch of both conds.
    train_vals: Passed to `gradients_impl.gradients` as the values to
      differentiate with respect to.
    feed_dict: Optional extra feeds merged into every `sess.run` call.
  """
  extra_feeds = feed_dict or {}
  with self.test_session(graph=ops.get_default_graph()) as sess:
    pred = array_ops.placeholder(dtypes.bool, name="pred")

    expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
    actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")

    expected_grad = gradients_impl.gradients(expected, train_vals)
    actual_grad = gradients_impl.gradients(actual, train_vals)

    fetches = (expected, actual, expected_grad, actual_grad)
    for pred_value in (True, False):
      feeds = {pred: pred_value}
      feeds.update(extra_feeds)
      expected_val, actual_val, expected_grad_val, actual_grad_val = (
          sess.run(fetches, feeds))
      self.assertEqual(expected_val, actual_val)
      self.assertEqual(expected_grad_val, actual_grad_val)
def build_graph():
  """Builds nested cond_v2 ops and takes their gradients inside a Defun.

  Returns:
    A tuple `(grads, pred_outer, pred_inner)`: the result of calling the
    defun that computes gradients of the outer cond with respect to
    `[x, y]`, followed by the two boolean placeholders.
  """
  pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
  pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
  x = constant_op.constant(1.0, name="x")
  y = constant_op.constant(2.0, name="y")

  def true_fn():
    return 2.0

  def false_fn():
    # The false branch of the outer cond is itself a cond.
    def inner_true_fn():
      return x * y * 2.0

    def inner_false_fn():
      return x * 5.0

    inner = cond_v2.cond_v2(
        pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
    return inner

  cond_outer = cond_v2.cond_v2(
      pred_outer, true_fn, false_fn, name="outer_cond")

  # Compute grads inside a Defun.
  @function.defun
  def nesting_fn():
    return gradients_impl.gradients(cond_outer, [x, y])

  return nesting_fn(), pred_outer, pred_inner
def testColocateWithInCondGraphPartitioning(self):
  """Checks that `colocate_with` inside a branch yields >= 2 partitions.

  `a` is placed on CPU:0 and `b` on CPU:1; the branch adds `a` to itself
  while colocated with `b`. The cond is run with
  `output_partition_graphs=True` and the number of partition graphs is
  asserted to be at least two.
  """
  with ops.Graph().as_default() as g:
    with self.test_session(
        graph=g,
        config=config_pb2.ConfigProto(device_count={"CPU": 2})
    ) as sess:

      with ops.device("/device:CPU:0"):
        a = constant_op.constant([2.0], name="a")
      with ops.device("/device:CPU:1"):
        b = constant_op.constant([2.0], name="b")

      def fn():
        with ops.colocate_with(b.op):
          c = math_ops.add(a, a, name="c")
        return c

      out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)

      # We expect there to be two partitions because of the
      # colocate_with. We are only running the cond, which has a data
      # dependency on `a` but not on `b`. So, without the colocate_with
      # we would expect execution on just one device.
      # assertGreaterEqual reports the actual count on failure, unlike
      # assertTrue on a pre-computed boolean.
      self.assertGreaterEqual(len(run_metadata.partition_graphs), 2)
def testDeviceBeforeCond(self):
  """Checks that an enclosing `ops.device` scope applies inside cond_v2.

  A constant created inside each branch is asserted to carry the device of
  the scope wrapping the cond_v2 call, first for CPU:0 and then for GPU:0.
  """
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g):

      def fn():
        c = constant_op.constant(3.0)
        self.assertEqual("/device:CPU:0", c.op.device)
        return c

      with ops.device("/device:CPU:0"):
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)

      def fn2():
        c = constant_op.constant(3.0)
        self.assertEqual("/device:GPU:0", c.op.device)
        return c

      with ops.device("/device:GPU:0"):
        self.assertEqual(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
def false_fn():
  """Returns an inner cond_v2 over `pred_inner`.

  `x`, `y` and `pred_inner` are read from the enclosing scope.
  """
  def inner_true_fn():
    doubled_product = x * y * 2.0
    return doubled_product

  def inner_false_fn():
    scaled = x * 5.0
    return scaled

  inner = cond_v2.cond_v2(
      pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
  return inner
def _createCond(self, name):
  """Builds a cond_v2 named `name` and returns the op of its first output.

  Args:
    name: Passed through as the `name` argument of `cond_v2.cond_v2`.

  Returns:
    The `.op` of element 0 of the cond_v2 result.
  """
  pred = constant_op.constant(True, name="pred")
  x = constant_op.constant(1.0, name="x")

  outputs = cond_v2.cond_v2(pred, lambda: x, lambda: x + 1, name=name)
  first_output = outputs[0]
  return first_output.op
def fnWithCond():  # pylint: disable=invalid-name
  """Returns the gradient of a cond_v2 result with respect to `v`.

  `v` is not defined here; it is read from the enclosing scope. The
  predicate is a constant True, so the branches are `pow(v, 3)` (true) and
  `v` (false).
  """
  with backprop.GradientTape() as tape:
    pred = constant_op.constant(True, dtype=dtypes.bool)

    def true_fn():
      return math_ops.pow(v, 3)

    def false_fn():
      return v

    cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
  # NOTE(review): the original indentation was lost; the gradient call is
  # assumed to come after the tape context has exited -- confirm.
  return tape.gradient(cond, v)
def testNoInputs(self):
  """Runs a cond_v2 whose branches only return constants."""
  with self.test_session() as sess:
    pred = array_ops.placeholder(dtypes.bool, name="pred")

    def true_fn():
      return constant_op.constant(1.0)

    def false_fn():
      return constant_op.constant(2.0)

    out = cond_v2.cond_v2(pred, true_fn, false_fn)

    # (value fed to pred, expected result of running the cond).
    for pred_value, wanted in ((True, (1.0,)), (False, (2.0,))):
      got = sess.run(out, {pred: pred_value})
      self.assertEqual(got, wanted)
def _createCond(self, name):
  """Helper function for testDefaultName.

  Builds a cond_v2 named `name`, asserts that the op feeding the output is
  of type "If", and returns that op.
  """
  pred = constant_op.constant(True, name="pred")
  x = constant_op.constant(1.0, name="x")

  output = cond_v2.cond_v2(pred, lambda: x, lambda: x + 1, name=name)

  # The cond op is reached through the first input of the output's op.
  producer = output.op.inputs[0]
  cond_op = producer.op
  self.assertEqual(cond_op.type, "If")
  return cond_op
def testCollectionIntValueAccessInCond(self):
  """Read values from graph collections inside of cond_v2."""
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g):
      x = 2
      y = 5
      ops.add_to_collection("x", x)
      ops.add_to_collection("y", y)

      def fn():
        # The plain ints stored above are read back and turned into
        # constants inside the branch.
        x_const = constant_op.constant(ops.get_collection("x")[0])
        y_const = constant_op.constant(ops.get_collection("y")[0])
        return math_ops.add(x_const, y_const)

      cnd = cond_v2.cond_v2(True, fn, fn)
      # assertEqual replaces the deprecated assertEquals alias.
      self.assertEqual(cnd[0].eval(), 7)
def _createNestedCond(self, name):
  """Like _createCond but creates a nested cond_v2 call as well.

  Returns:
    A tuple `(output, cond_op)` for the outer cond; `cond_op` is asserted
    to be of type "If".
  """
  pred = constant_op.constant(True, name="pred")
  x = constant_op.constant(1.0, name="x")

  def true_fn():
    def inner_true():
      return x

    def inner_false():
      return x + 1

    return cond_v2.cond_v2(pred, inner_true, inner_false)

  output = cond_v2.cond_v2(pred, true_fn, lambda: x + 2, name=name)

  producer = output.op.inputs[0]
  cond_op = producer.op
  self.assertEqual(cond_op.type, "If")
  return output, cond_op
def _createCond(self, name):
  """Creates a cond_v2 call and returns the output tensor and the cond op.

  The cond op is asserted to be of type "If" before it is returned.
  """
  pred = constant_op.constant(True, name="pred")
  x = constant_op.constant(1.0, name="x")

  output = cond_v2.cond_v2(pred, lambda: x, lambda: x + 1, name=name)

  producer = output.op.inputs[0]
  cond_op = producer.op
  self.assertEqual(cond_op.type, "If")
  return output, cond_op
def testCollectionIntValueAccessInCond(self):
  """Read values from graph collections inside of cond_v2."""
  with ops.Graph().as_default() as g:
    with self.session(graph=g):
      x = 2
      y = 5
      ops.add_to_collection("x", x)
      ops.add_to_collection("y", y)

      def fn():
        # The plain ints stored above are read back and turned into
        # constants inside the branch.
        x_const = constant_op.constant(ops.get_collection("x")[0])
        y_const = constant_op.constant(ops.get_collection("y")[0])
        return math_ops.add(x_const, y_const)

      cnd = cond_v2.cond_v2(constant_op.constant(True), fn, fn)
      # assertEqual replaces the deprecated assertEquals alias.
      self.assertEqual(cnd.eval(), 7)
def testCollectionTensorValueAccessInCond(self):
  """Read tensors from collections inside of cond_v2 & use them."""
  with ops.Graph().as_default() as g:
    with self.session(graph=g):
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      ops.add_to_collection("x", x)
      ops.add_to_collection("y", y)

      def fn():
        # Tensors created outside the cond are fetched via the collections.
        x_read = ops.get_collection("x")[0]
        y_read = ops.get_collection("y")[0]
        return math_ops.add(x_read, y_read)

      cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
      # assertEqual replaces the deprecated assertEquals alias.
      self.assertEqual(cnd.eval(), 7)
def testForwardPassRewrite(self):
  """Checks how taking gradients changes the forward If op's outputs.

  The op starts with one output, has two after the first gradient
  computation, and still has two after a second one.
  """
  x = constant_op.constant(1.0, name="x")
  pred = constant_op.constant(True)
  output = cond_v2.cond_v2(pred, lambda: x * 2.0, lambda: x)

  if_op = output.op.inputs[0].op
  self.assertEqual(if_op.type, "If")
  # pylint: disable=g-deprecated-assert
  self.assertEqual(len(if_op.outputs), 1)

  # First pass: if_op should have been rewritten to output 2.0 intermediate.
  # Second pass: computing the gradient again shouldn't rewrite if_op again.
  for _ in range(2):
    gradients_impl.gradients(output, x)
    self.assertEqual(len(if_op.outputs), 2)
def testDeviceInAndOutOfCond(self):
  """Checks device scopes set inside a branch and around the cond.

  The branch places its constant on CPU:1 while the cond itself is built
  under CPU:0; a constant created after the cond, still inside the outer
  scope, is asserted to be on CPU:0.
  """
  with ops.Graph().as_default() as g:
    with self.test_session(
        graph=g, config=config_pb2.ConfigProto(device_count={"CPU": 2})):

      def fn2():
        with ops.device("/device:CPU:1"):
          c = constant_op.constant(3.0)
          self.assertEqual("/device:CPU:1", c.op.device)
          return c

      with ops.device("/device:CPU:0"):
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)

        d = constant_op.constant(4.0)
        self.assertEqual("/device:CPU:0", d.op.device)
def testDeviceInAndOutOfCond(self):
  """Checks device scopes set inside a branch and around the cond.

  The branch places its constant on GPU:0 while the cond itself is built
  under CPU:0; a constant created after the cond, still inside the outer
  scope, is asserted to be on CPU:0.
  """
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g):

      def fn2():
        with ops.device("/device:GPU:0"):
          c = constant_op.constant(3.0)
          self.assertEqual("/device:GPU:0", c.op.device)
          return c

      with ops.device("/device:CPU:0"):
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)

        d = constant_op.constant(4.0)
        self.assertEqual("/device:CPU:0", d.op.device)
def testCollectionTensorValueAccessInCond(self):
  """Read tensors from collections inside of cond_v2 & use them."""
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g):
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      ops.add_to_collection("x", x)
      ops.add_to_collection("y", y)

      def fn():
        # Tensors created outside the cond are fetched via the collections.
        x_read = ops.get_collection("x")[0]
        y_read = ops.get_collection("y")[0]
        return math_ops.add(x_read, y_read)

      cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
      # assertEqual replaces the deprecated assertEquals alias.
      self.assertEqual(cnd[0].eval(), 7)
def testDeviceInAndOutOfCond(self):
  """Checks device scopes set inside a branch and around the cond.

  The branch places its constant on CPU:1 while the cond itself is built
  under CPU:0; a constant created after the cond, still inside the outer
  scope, is asserted to be on CPU:0.
  """
  with ops.Graph().as_default() as g:
    with self.session(
        graph=g, config=config_pb2.ConfigProto(device_count={"CPU": 2})):

      def fn2():
        with ops.device("/device:CPU:1"):
          c = constant_op.constant(3.0)
          self.assertEqual("/device:CPU:1", c.op.device)
          return c

      with ops.device("/device:CPU:0"):
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(
            cond_v2.cond_v2(constant_op.constant(True), fn2, fn2).eval(), 3)

        d = constant_op.constant(4.0)
        self.assertEqual("/device:CPU:0", d.op.device)
def testForwardPassRewrite(self):
  """Checks how taking gradients changes the forward If op's outputs.

  The op starts with one output, has two after the first gradient
  computation, and still has two after a second one.
  """
  x = constant_op.constant(1.0, name="x")
  pred = constant_op.constant(True)
  output = cond_v2.cond_v2(pred, lambda: x * 2.0, lambda: x)

  if_op = output.op.inputs[0].op
  self.assertEqual(if_op.type, "If")
  # pylint: disable=g-deprecated-assert
  self.assertEqual(len(if_op.outputs), 1)

  # First pass: if_op should have been rewritten to output 2.0 intermediate.
  # Second pass: computing the gradient again shouldn't rewrite if_op again.
  for _ in range(2):
    gradients_impl.gradients(output, x)
    self.assertEqual(len(if_op.outputs), 2)
def testDoNotAccumulateConstants(self):
  """Checks that taking gradients adds no outputs to the StatelessIf op.

  The op is asserted to have exactly one output before any gradient is
  computed and after each of two gradient computations.
  """
  x = constant_op.constant(1.0, name="x")
  pred = constant_op.constant(True)
  output = cond_v2.cond_v2(pred, lambda: x * 2.0, lambda: x)

  if_op = output.op.inputs[0].op
  self.assertEqual(if_op.type, "StatelessIf")
  # pylint: disable=g-deprecated-assert
  self.assertEqual(len(if_op.outputs), 1)

  # Number of outputs does not change because
  # 1. `x` is an input so does not need to be accumulated.
  # 2. 2.0 is a constant so it is not accumulated.
  # Computing the gradient a second time shouldn't rewrite if_op either.
  for _ in range(2):
    gradients_impl.gradients(output, x)
    self.assertEqual(len(if_op.outputs), 1)
def testColocateWithInAndOutOfCond(self):
  """Checks colocation scopes set inside a branch and around the cond.

  The branch colocates its constant with `b` while the cond is built under
  a scope colocated with `a`, so the constant is asserted to belong to both
  groups. A constant created after the cond, still inside the outer scope,
  is asserted to belong to `a` only.
  """
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g):

      a = constant_op.constant([2.0], name="a")
      b = constant_op.constant([2.0], name="b")

      def fn2():
        with ops.colocate_with(b.op):
          c = constant_op.constant(3.0)
          self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
          return c

      with ops.colocate_with(a.op):
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)

        d = constant_op.constant([2.0], name="d")
        self.assertEqual([b"loc:@a"], d.op.colocation_groups())
def fn():
  """Builds nested cond_v2 ops and returns gradients w.r.t. `[x, y]`.

  `x`, `y`, `pred_outer` and `pred_inner` are read from the enclosing scope.
  """
  def true_fn():
    return 2.0

  def false_fn():
    # The false branch of the outer cond is itself a cond.
    def inner_true_fn():
      return x * y * 2.0

    def inner_false_fn():
      return x * 5.0

    inner = cond_v2.cond_v2(
        pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
    return inner

  cond_outer = cond_v2.cond_v2(
      pred_outer, true_fn, false_fn, name="outer_cond")
  grads = gradients_impl.gradients(cond_outer, [x, y])
  return grads
def fn():
  """Builds nested cond_v2 ops and returns gradients w.r.t. `[x, y]`.

  `x`, `y`, `pred_outer` and `pred_inner` are read from the enclosing scope.
  """
  def true_fn():
    return 2.0

  def false_fn():
    # The false branch of the outer cond is itself a cond.
    def inner_true_fn():
      return x * y * 2.0

    def inner_false_fn():
      return x * 5.0

    inner = cond_v2.cond_v2(
        pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
    return inner

  cond_outer = cond_v2.cond_v2(
      pred_outer, true_fn, false_fn, name="outer_cond")
  grads = gradients_impl.gradients(cond_outer, [x, y])
  return grads
def _testCond(self, true_fn, false_fn, train_vals):
  """Compares cond_v2 against control_flow_ops.cond for both predicates.

  Args:
    true_fn: Callable used as the true branch of both conds.
    false_fn: Callable used as the false branch of both conds.
    train_vals: Passed to `gradients_impl.gradients` as the values to
      differentiate with respect to.
  """
  with self.test_session() as sess:
    pred = array_ops.placeholder(dtypes.bool, name="pred")

    expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
    actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")

    expected_grad = gradients_impl.gradients(expected, train_vals)
    actual_grad = gradients_impl.gradients(actual, train_vals)

    fetches = (expected, actual, expected_grad, actual_grad)
    for pred_value in (True, False):
      expected_val, actual_val, expected_grad_val, actual_grad_val = (
          sess.run(fetches, {pred: pred_value}))
      self.assertEqual(expected_val, actual_val)
      self.assertEqual(expected_grad_val, actual_grad_val)
def testCollectionIntValueWriteInCond(self):
  """Make sure Int writes to collections work inside of cond_v2."""
  with ops.Graph().as_default() as g:
    with self.session(graph=g):
      x = constant_op.constant(2)
      y = constant_op.constant(5)

      def true_fn():
        z = math_ops.add(x, y)
        # Written while the branch is being built; read back after the cond.
        ops.add_to_collection("z", 7)
        return math_ops.mul(x, z)

      def false_fn():
        z = math_ops.add(x, y)
        return math_ops.mul(x, z)

      cnd = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn)
      # assertEqual replaces the deprecated assertEquals alias.
      self.assertEqual(cnd.eval(), 14)

      read_z_collection = ops.get_collection("z")
      self.assertEqual(read_z_collection, [7])
def testCollectionIntValueWriteInCond(self):
  """Make sure Int writes to collections work inside of cond_v2."""
  with ops.Graph().as_default() as g:
    with self.session(graph=g):
      x = constant_op.constant(2)
      y = constant_op.constant(5)

      def true_fn():
        z = math_ops.add(x, y)
        # Written while the branch is being built; read back after the cond.
        ops.add_to_collection("z", 7)
        return math_ops.mul(x, z)

      def false_fn():
        z = math_ops.add(x, y)
        return math_ops.mul(x, z)

      cnd = cond_v2.cond_v2(constant_op.constant(True), true_fn, false_fn)
      # assertEqual replaces the deprecated assertEquals alias.
      self.assertEqual(cnd.eval(), 14)

      read_z_collection = ops.get_collection("z")
      self.assertEqual(read_z_collection, [7])
def testGradientOfDeserializedCond(self):
  """Takes gradients of a cond_v2 that went through a meta graph round trip.

  The cond, its predicate and `x` are registered in collections, exported
  with `saver.export_meta_graph`, imported into a fresh graph, and then the
  first and second derivatives are checked for both predicate values.
  """
  with ops.Graph().as_default():
    pred = array_ops.placeholder(dtypes.bool, name="pred")
    x = constant_op.constant(3.0, name="x")
    ops.add_to_collection("x", x)

    def true_fn():
      return math_ops.pow(x, 3)

    def false_fn():
      return x

    ops.add_to_collection("pred", pred)
    cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
    for cond_output in cond:
      ops.add_to_collection("cond", cond_output)
    meta_graph = saver.export_meta_graph()

  with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
    saver.import_meta_graph(meta_graph)
    x = ops.get_collection("x")[0]
    pred = ops.get_collection("pred")[0]
    cond = ops.get_collection("cond")

    first_order = gradients_impl.gradients(cond, [x], name="cond_grad")
    second_order = gradients_impl.gradients(
        first_order, [x], name="cond_grad_grad")

    # (fetch, value fed to pred, expected result), listed in run order.
    cases = (
        (first_order, True, [27.0]),    # d[x^3]/dx = 3x^2
        (first_order, False, [1.0]),    # d[x]/dx = 1
        (second_order, True, [18.0]),   # d2[x^3]/dx2 = 6x
        (second_order, False, [0.0]),   # d2[x]/dx2 = 0
    )
    for fetch, pred_value, wanted in cases:
      got = sess.run(fetch, {pred: pred_value})
      self.assertEqual(got, wanted)
def testForwardPassRewrite(self):
  """Checks how taking gradients changes the StatelessIf op's outputs.

  The op starts with one output, has two after the first gradient
  computation, and still has two after a second one.
  """
  x = constant_op.constant(1.0, name="x")
  y = constant_op.constant(1.0, name="y")

  def true_fn():
    y_plus_one = y + 1.
    return x * y_plus_one

  pred = constant_op.constant(True)
  output = cond_v2.cond_v2(pred, true_fn, lambda: x)

  if_op = output.op.inputs[0].op
  self.assertEqual(if_op.type, "StatelessIf")
  # pylint: disable=g-deprecated-assert
  self.assertEqual(len(if_op.outputs), 1)

  # First pass: if_op should have been rewritten to output `y_plus_one`.
  # Second pass: computing the gradient again shouldn't rewrite if_op again.
  for _ in range(2):
    gradients_impl.gradients(output, x)
    self.assertEqual(len(if_op.outputs), 2)
def testDeviceInCondGraphPartitioning(self):
  """Checks that a device scope inside a branch yields >= 2 partitions.

  `a` lives on CPU:0 and the branch adds it to itself under a CPU:1 scope.
  The cond is run with `output_partition_graphs=True` and the number of
  partition graphs is asserted to be at least two.
  """
  with ops.Graph().as_default() as g:
    with self.test_session(
        graph=g,
        config=config_pb2.ConfigProto(device_count={"CPU": 2})
    ) as sess:

      def fn():
        # `a` is bound below, before cond_v2 invokes this function.
        with ops.device("/device:CPU:1"):
          c = math_ops.add(a, a, name="c")
        return c

      with ops.device("/device:CPU:0"):
        a = constant_op.constant([2.0], name="a")
        out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)

      # assertGreaterEqual reports the actual count on failure, unlike
      # assertTrue on a pre-computed boolean.
      self.assertGreaterEqual(len(run_metadata.partition_graphs), 2)
def testDeviceInCondGraphPartitioning(self):
  """Checks that a device scope inside a branch yields >= 2 partitions.

  `a` lives on CPU:0 and the branch adds it to itself under a CPU:1 scope.
  The cond is run with `output_partition_graphs=True` and the number of
  partition graphs is asserted to be at least two.
  """
  with ops.Graph().as_default() as g:
    with self.session(
        graph=g,
        config=config_pb2.ConfigProto(device_count={"CPU": 2})
    ) as sess:

      def fn():
        # `a` is bound below, before cond_v2 invokes this function.
        with ops.device("/device:CPU:1"):
          c = math_ops.add(a, a, name="c")
        return c

      with ops.device("/device:CPU:0"):
        a = constant_op.constant([2.0], name="a")
        out_cond_2 = cond_v2.cond_v2(constant_op.constant(True), fn, fn)

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)

      # assertGreaterEqual reports the actual count on failure, unlike
      # assertTrue on a pre-computed boolean.
      self.assertGreaterEqual(len(run_metadata.partition_graphs), 2)
def _testCond(self, true_fn, false_fn, train_vals, feed_dict=None):
  """Compares cond_v2 against control_flow_ops.cond for several predicates.

  The reference cond receives `squeeze_v2(pred)` while cond_v2 receives
  `pred` directly; both are run with `pred` fed as True, [[True]], False
  and [[False]], and their values and gradients are asserted equal.

  Args:
    true_fn: Callable used as the true branch of both conds.
    false_fn: Callable used as the false branch of both conds.
    train_vals: Passed to `gradients_impl.gradients` as the values to
      differentiate with respect to.
    feed_dict: Optional extra feeds merged into every `sess.run` call.
  """
  extra_feeds = feed_dict or {}
  with self.session(graph=ops.get_default_graph()) as sess:
    pred = array_ops.placeholder(dtypes.bool, name="pred")

    expected = control_flow_ops.cond(
        array_ops.squeeze_v2(pred), true_fn, false_fn, name="expected")
    actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")

    expected_grad = gradients_impl.gradients(expected, train_vals)
    actual_grad = gradients_impl.gradients(actual, train_vals)

    fetches = (expected, actual, expected_grad, actual_grad)
    # Scalar and nested-list forms of each predicate value, in run order.
    for pred_value in (True, [[True]], False, [[False]]):
      feeds = {pred: pred_value}
      feeds.update(extra_feeds)
      expected_val, actual_val, expected_grad_val, actual_grad_val = (
          sess.run(fetches, feeds))
      self.assertEqual(expected_val, actual_val)
      self.assertEqual(expected_grad_val, actual_grad_val)
def testGradientOfDeserializedCond(self):
  """Takes gradients of a cond_v2 that went through a meta graph round trip.

  The cond, its predicate and `x` are registered in collections, exported
  with `saver.export_meta_graph`, imported into a fresh graph, and then the
  first and second derivatives are checked for both predicate values.
  """
  with ops.Graph().as_default():
    pred = array_ops.placeholder(dtypes.bool, name="pred")
    x = constant_op.constant(3.0, name="x")
    ops.add_to_collection("x", x)

    def true_fn():
      return math_ops.pow(x, 3)

    def false_fn():
      return x

    ops.add_to_collection("pred", pred)
    cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
    for cond_output in cond:
      ops.add_to_collection("cond", cond_output)
    meta_graph = saver.export_meta_graph()

  with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
    saver.import_meta_graph(meta_graph)
    x = ops.get_collection("x")[0]
    pred = ops.get_collection("pred")[0]
    cond = ops.get_collection("cond")

    first_order = gradients_impl.gradients(cond, [x], name="cond_grad")
    second_order = gradients_impl.gradients(
        first_order, [x], name="cond_grad_grad")

    # (fetch, value fed to pred, expected result), listed in run order.
    cases = (
        (first_order, True, [27.0]),    # d[x^3]/dx = 3x^2
        (first_order, False, [1.0]),    # d[x]/dx = 1
        (second_order, True, [18.0]),   # d2[x^3]/dx2 = 6x
        (second_order, False, [0.0]),   # d2[x]/dx2 = 0
    )
    for fetch, pred_value, wanted in cases:
      got = sess.run(fetch, {pred: pred_value})
      self.assertEqual(got, wanted)
def build_graph():
  """Builds nested cond_v2 ops and takes gradients inside nested Defuns.

  Returns:
    A tuple `(grads, pred_outer, pred_inner)`: the result of calling a
    defun that calls another defun computing gradients of the outer cond
    with respect to `[x, y]`, followed by the two boolean placeholders.
  """
  pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
  pred_inner = array_ops.placeholder(dtypes.bool, name="pred_inner")
  x = constant_op.constant(1.0, name="x")
  y = constant_op.constant(2.0, name="y")

  def true_fn():
    return 2.0

  def false_fn():
    # The false branch of the outer cond is itself a cond.
    def inner_true_fn():
      return x * y * 2.0

    def inner_false_fn():
      return x * 5.0

    inner = cond_v2.cond_v2(
        pred_inner, inner_true_fn, inner_false_fn, name="inner_cond")
    return inner

  cond_outer = cond_v2.cond_v2(
      pred_outer, true_fn, false_fn, name="outer_cond")

  # Compute grads inside a Defun.
  @function.defun
  def nesting_fn():

    @function.defun
    def inner_nesting_fn():
      return gradients_impl.gradients(cond_outer, [x, y])

    return inner_nesting_fn()

  return nesting_fn(), pred_outer, pred_inner
def fn_with_cond():
  """Builds four cond_v2 ops whose branches assign to `v1` / `v2`.

  `v1` and `v2` are read from the enclosing scope. cond_1 and cond_3 take a
  constant True predicate and put the assignment in the true branch;
  cond_2 and cond_4 take a constant False predicate and put it in the false
  branch. cond_4's false branch is wrapped in `def_function.function`.

  Returns:
    A tuple `(cond_2, cond_4)`.
  """
  def update_v1():
    v1.assign(v1)
    return v1

  def update_v2():
    v2.assign(v2)
    return v2

  def zero_fn():
    # Branch that touches no variable.
    return constant_op.constant(0.)

  cond_v2.cond_v2(
      constant_op.constant(True), update_v1, zero_fn, name="cond_1")
  cond_2 = cond_v2.cond_v2(
      constant_op.constant(False), zero_fn, update_v1, name="cond_2")
  cond_v2.cond_v2(
      constant_op.constant(True), update_v2, zero_fn, name="cond_3")

  @def_function.function
  def cond_4_false_branch():
    v2.assign(v2)
    return v2

  cond_4 = cond_v2.cond_v2(
      constant_op.constant(False), zero_fn, cond_4_false_branch,
      name="cond_4")
  return cond_2, cond_4
def recursive_fn(n):
  """Calls itself while building the arguments of a cond_v2 call.

  NOTE(review): `recursive_fn(n - 1)` is an ordinary Python call, so it is
  evaluated before `cond_v2.cond_v2` is invoked; nothing visible here stops
  the recursion. The second and third arguments are also a call result and
  the literal 1 rather than callables. Presumably this is deliberate (e.g.
  exercising unbounded recursion) -- confirm against the caller.
  """
  return cond_v2.cond_v2(n > 0, recursive_fn(n - 1), 1)
def true_fn():
  """Returns a cond_v2 on `pred` yielding `x` or `x + 1`.

  `pred` and `x` are read from the enclosing scope.
  """
  def keep_x():
    return x

  def x_plus_one():
    return x + 1

  return cond_v2.cond_v2(pred, keep_x, x_plus_one)
def _cond(pred, true_fn, false_fn, name):
  """Dispatches to the old or the v2 cond implementation.

  Uses `control_flow_ops.cond` when `_is_old_cond()` is truthy and
  `cond_v2.cond_v2` otherwise; all arguments are forwarded unchanged.
  """
  cond_impl = control_flow_ops.cond if _is_old_cond() else cond_v2.cond_v2
  return cond_impl(pred, true_fn, false_fn, name=name)
def _add_cond(x):
  """Wraps `x` in a cond_v2 with a constant True predicate named "pred".

  The true branch returns `x`; the false branch returns `x + 1`.
  """
  pred = constant_op.constant(True, name="pred")

  def keep_x():
    return x

  def x_plus_one():
    return x + 1

  return cond_v2.cond_v2(pred, keep_x, x_plus_one)
def true_fn():
  """Returns a cond_v2 on `pred` yielding `x` or `x + 1`.

  `pred` and `x` are read from the enclosing scope.
  """
  def keep_x():
    return x

  def x_plus_one():
    return x + 1

  return cond_v2.cond_v2(pred, keep_x, x_plus_one)
def testContainer(self):
  """Set containers outside & inside of cond_v2.

  Make sure the containers are set correctly for both variable creation
  (tested by variables.Variable) and for stateful ops (tested by FIFOQueue).

  The test currently calls `skipTest("b/113048653")` first, so the body
  below that call does not run.
  """
  self.skipTest("b/113048653")
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g):

      v0 = variables.Variable([0])
      q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      def container(node):
        return node.op.get_attr("container")

      self.assertEqual(compat.as_bytes(""), container(v0))
      self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))

      def true_fn():
        # When this branch is created in cond below,
        # the container should begin with 'l1'
        v1 = variables.Variable([1])
        q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        with ops.container("l2t"):
          v2 = variables.Variable([2])
          q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        v3 = variables.Variable([1])
        q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        self.assertEqual(compat.as_bytes("l1"), container(v1))
        self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
        self.assertEqual(compat.as_bytes("l2t"), container(v2))
        self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
        self.assertEqual(compat.as_bytes("l1"), container(v3))
        self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

        return constant_op.constant(2.0)

      def false_fn():
        # When this branch is created in cond below,
        # the container should begin with 'l1'
        v1 = variables.Variable([1])
        q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        with ops.container("l2f"):
          v2 = variables.Variable([2])
          q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        v3 = variables.Variable([1])
        q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        self.assertEqual(compat.as_bytes("l1"), container(v1))
        self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
        self.assertEqual(compat.as_bytes("l2f"), container(v2))
        self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
        self.assertEqual(compat.as_bytes("l1"), container(v3))
        self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

        return constant_op.constant(6.0)

      with ops.container("l1"):
        cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(cnd_true[0].eval(), 2)

        cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
        self.assertEqual(cnd_false[0].eval(), 6)

        # Created inside the "l1" container scope.
        v4 = variables.Variable([3])
        q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
      # Created after the "l1" scope has exited.
      v5 = variables.Variable([4])
      q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      self.assertEqual(compat.as_bytes("l1"), container(v4))
      self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
      self.assertEqual(compat.as_bytes(""), container(v5))
      self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
def _cond(pred, true_fn, false_fn, name):
  """Dispatches to the old or the v2 cond implementation.

  Uses `control_flow_ops.cond` when `_is_old_cond()` is truthy and
  `cond_v2.cond_v2` otherwise; all arguments are forwarded unchanged.
  """
  cond_impl = control_flow_ops.cond if _is_old_cond() else cond_v2.cond_v2
  return cond_impl(pred, true_fn, false_fn, name=name)
def _add_cond(x):
  """Wraps `x` in a cond_v2 with a constant True predicate named "pred".

  The true branch returns `x`; the false branch returns `x + 1`.
  """
  pred = constant_op.constant(True, name="pred")

  def keep_x():
    return x

  def x_plus_one():
    return x + 1

  return cond_v2.cond_v2(pred, keep_x, x_plus_one)
def testContainer(self):
  """Set containers outside & inside of cond_v2.

  Make sure the containers are set correctly for both variable creation
  (tested by variables.Variable) and for stateful ops (tested by FIFOQueue).

  The test currently calls `skipTest("b/113048653")` first, so the body
  below that call does not run.
  """
  self.skipTest("b/113048653")
  with ops.Graph().as_default() as g:
    with self.session(graph=g):

      v0 = variables.Variable([0])
      q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      def container(node):
        return node.op.get_attr("container")

      self.assertEqual(compat.as_bytes(""), container(v0))
      self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))

      def true_fn():
        # When this branch is created in cond below,
        # the container should begin with 'l1'
        v1 = variables.Variable([1])
        q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        with ops.container("l2t"):
          v2 = variables.Variable([2])
          q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        v3 = variables.Variable([1])
        q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        self.assertEqual(compat.as_bytes("l1"), container(v1))
        self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
        self.assertEqual(compat.as_bytes("l2t"), container(v2))
        self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
        self.assertEqual(compat.as_bytes("l1"), container(v3))
        self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

        return constant_op.constant(2.0)

      def false_fn():
        # When this branch is created in cond below,
        # the container should begin with 'l1'
        v1 = variables.Variable([1])
        q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        with ops.container("l2f"):
          v2 = variables.Variable([2])
          q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        v3 = variables.Variable([1])
        q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)

        self.assertEqual(compat.as_bytes("l1"), container(v1))
        self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
        self.assertEqual(compat.as_bytes("l2f"), container(v2))
        self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
        self.assertEqual(compat.as_bytes("l1"), container(v3))
        self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))

        return constant_op.constant(6.0)

      with ops.container("l1"):
        cnd_true = cond_v2.cond_v2(
            constant_op.constant(True), true_fn, false_fn)
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(cnd_true.eval(), 2)

        cnd_false = cond_v2.cond_v2(
            constant_op.constant(False), true_fn, false_fn)
        self.assertEqual(cnd_false.eval(), 6)

        # Created inside the "l1" container scope.
        v4 = variables.Variable([3])
        q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
      # Created after the "l1" scope has exited.
      v5 = variables.Variable([4])
      q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)

      self.assertEqual(compat.as_bytes("l1"), container(v4))
      self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
      self.assertEqual(compat.as_bytes(""), container(v5))
      self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
def model(b):
  """Returns `cond_v2.cond_v2(b, true_fn, false_fn)`.

  `true_fn` and `false_fn` are read from the enclosing scope.
  """
  result = cond_v2.cond_v2(b, true_fn, false_fn)
  return result