def testLoweringDisabledInXLA(self):
  with self.test_session(graph=ops.Graph()) as sess:
    # Build the cond_v2 in an XLA context
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    out_cond = self._createCond("cond")
    xla_context.Exit()

    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    run_metadata = config_pb2.RunMetadata()
    sess.run(out_cond, options=run_options, run_metadata=run_metadata)

    # Lowering disabled in XLA, there should be no `Switch` node
    switch_found = any(
        any(node.op == "Switch" for node in graph.node)
        for graph in run_metadata.partition_graphs)
    self.assertFalse(
        switch_found,
        "A `Switch` op exists, but the graph should not be lowered.")

    # Lowering disabled in XLA, there should still be an `If` node
    if_found = any(
        any(node.op == "If" for node in graph.node)
        for graph in run_metadata.partition_graphs)
    self.assertTrue(
        if_found,
        "An `If` op was not found, but the graph should not be lowered.")

def testLoweringDisabledInXLA(self):
  with self.session(graph=ops.Graph()) as sess:
    # Build the cond_v2 in an XLA context
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    cond_output, cond_op = self._createCond("cond")
    xla_context.Exit()

    # Check lowering attr is not set.
    with self.assertRaises(ValueError):
      cond_op.get_attr("_lower_using_switch_merge")

    # Check the actual graph that is run.
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    run_metadata = config_pb2.RunMetadata()
    sess.run(cond_output, options=run_options, run_metadata=run_metadata)

    # Lowering disabled in XLA, there should be no `Switch` node
    self.assertFalse(
        _has_node_with_op(run_metadata, "Switch"),
        "A `Switch` op exists, but the graph should not be lowered.")

    # Lowering disabled in XLA, there should still be a `StatelessIf` node
    self.assertTrue(
        _has_node_with_op(run_metadata, "StatelessIf"),
        "A `StatelessIf` op was not found, but the graph should not be "
        "lowered.")

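# `_has_node_with_op` is not defined in this excerpt. A minimal sketch,
# consistent with the inline check in the older variant of this test above
# (`run_metadata` is a `config_pb2.RunMetadata` produced by a run with
# `output_partition_graphs=True`):


def _has_node_with_op(run_metadata, op_type):
  """Returns True if any partition graph contains a node of type `op_type`."""
  return any(
      any(node.op == op_type for node in graph.node)
      for graph in run_metadata.partition_graphs)
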
def testSwitchCaseConstPropagation(self):
  self.skipTest("b/127846988")
  with self.session() as sess, self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    x = array_ops.placeholder(dtypes.float32)
    p = array_ops.placeholder(dtypes.int32)

    def branch0():
      return 5.

    def branch1():
      return 15.

    # TODO(b/129021699): Wrapping this in a tf.function does not work.
    def branch2():
      # This emits a StridedSlice op which expects the index to be a
      # compile-time const.
      return x[p]

    output = control_flow_ops.switch_case(
        constant_op.constant(2), {
            0: branch0,
            1: branch1,
            2: branch2,
        })

    self.assertAllEqual(
        7., sess.run(output, feed_dict={
            x: [0., 1., 7.],
            p: 2,
        }))

    xla_context.Exit()

def testCondConstPropagation_xlaCompile(self):
  self.skipTest("b/132430685")
  with self.session(), self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    x = array_ops.placeholder_with_default([0., 1., 2.], shape=[3])
    p = constant_op.constant(1)

    def f():
      # TODO(b/129021699): Wrapping this in a tf.function does not work.
      def if_true():
        # This emits a StridedSlice op which expects the index to be a
        # compile-time const.
        return x[p]

      def if_false():
        return 5.

      return control_flow_ops.cond(
          constant_op.constant(True), if_true, if_false)

    output = xla.compile(f)

    self.assertAllEqual(1., self.evaluate(output))

    xla_context.Exit()

def testCondConstPropagation_errorMsg(self):
  self.skipTest("b/132430685")
  with self.session() as sess, self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    x = array_ops.placeholder(dtypes.float32)
    p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32)

    # TODO(b/129021699): Wrapping this in a tf.function does not work.
    def if_true():
      # This emits a StridedSlice op which expects the index to be a
      # compile-time const.
      return x[:p]

    def if_false():
      return array_ops.fill([p], 5.)

    output = control_flow_ops.cond(
        constant_op.constant(True), if_true, if_false)

    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "must be a compile-time constant"):
      sess.run(output, feed_dict={
          x: [0., 1., 2.],
      })

    xla_context.Exit()

def testCondConstPropagation(self):
  with self.session() as sess, self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    x = array_ops.placeholder(dtypes.float32)
    p = array_ops.placeholder(dtypes.int32)

    # TODO(b/129021699): Wrapping this in a tf.function does not work.
    def if_true():
      # This emits a StridedSlice op which expects the index to be a
      # compile-time const.
      return x[p]

    def if_false():
      return 5.

    output = control_flow_ops.cond(
        constant_op.constant(True), if_true, if_false)

    self.assertAllEqual(
        1., sess.run(output, feed_dict={
            x: [0., 1., 2.],
            p: 1
        }))

    xla_context.Exit()

def testMap(self):
  if is_compile_on_demand():
    self.skipTest("list_ops are not supported in cpu_ondemand")
  with self.session(), self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    nums = [1, 2, 3, 4, 5, 6]
    elems = constant_op.constant(nums, name="data")
    r = map_fn.map_fn(lambda x: math_ops.multiply(math_ops.add(x, 3), 2),
                      elems)
    self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums]))
    xla_context.Exit()

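# `is_compile_on_demand` is defined outside this excerpt. One plausible
# implementation, under the assumption that the tests only need to know
# whether the XLA device compiles "on demand" via an environment flag (the
# exact flag string here is an assumption, not the verbatim original):

import os


def is_compile_on_demand():
  """Whether the test runs under XLA compile-on-demand (assumed heuristic)."""
  return ("TF_XLA_FLAGS" in os.environ and
          "tf_xla_compile_on_demand" in os.environ["TF_XLA_FLAGS"])
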
def _testNestedWhileLoopWithMaxItersFromOuterContext(self):
  if is_compile_on_demand():
    self.skipTest("list_ops are not supported in cpu_ondemand")
  with self.session() as sess, self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    v = constant_op.constant(1.0)
    p = array_ops.placeholder(dtype=dtypes.int32)

    def mid_body_builder(iterations):

      def mid_body(i, x):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            lambda i, x: (i + 1, v * x), (0, x),
            maximum_iterations=iterations,
            name="inner")
        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

      return mid_body

    def outer_body(i, x):
      iterations = array_ops.size(p, name="iterations")
      return (i + 1, x + control_flow_ops.while_loop(
          lambda *_: True,
          mid_body_builder(iterations), (0, x),
          maximum_iterations=iterations,
          name="mid")[1])

    def create_while_loop():
      r = control_flow_ops.while_loop(
          lambda *_: True,
          outer_body, (0, 1.0),
          maximum_iterations=5,
          name="outer")
      return array_ops.identity(r[1])

    # p: placeholder
    # j = 0
    # i, x = 0, 1.
    # while j++ < 5:
    #   i1, x1 = 0, x
    #   while i1++ < len(p):
    #     i2, x2 = 0, x1
    #     while i2++ < len(p):
    #       x2 = v * x2
    #     x1 = grad(x1 + x2, v)
    #   x = x1
    # output = x
    output = create_while_loop()
    sess.run(output, feed_dict={p: [0, 0, 0]})
    xla_context.Exit()

def _call_for_each_replica(self, fn, args, kwargs):
  with distribute_lib.ReplicaContext(
      self._container_strategy(), replica_id_in_sync_group=0), \
      ops.device(self._ipu_device):
    # Make sure it is compiled as a single engine when called in graph mode.
    # This is similar to the mechanism used by xla.compile.
    xla_context = control_flow_ops.XLAControlFlowContext()
    try:
      xla_context.Enter()
      _validate_function_for_arguments(fn, args, kwargs)
      return fn(*args, **kwargs)
    finally:
      xla_context.Exit()

def testCondNoInputs(self):
  """Verifies against `Failed precondition: Expected one input shape`."""
  with self.session(), self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    for pred in True, False:
      cond_out = control_flow_ops.cond(
          array_ops.placeholder_with_default(pred, []),
          lambda: constant_op.constant(2.),
          lambda: constant_op.constant(1.))
      self.assertEqual(int(pred) + 1., self.evaluate(cond_out))
    xla_context.Exit()

def testNestedLoweringDisabledInXLA(self):
  # Build the cond_v2 in an XLA context
  xla_context = control_flow_ops.XLAControlFlowContext()
  xla_context.Enter()
  _, cond_op = self._createNestedCond("cond")
  xla_context.Exit()

  # Check lowering attr is not set for either If node.
  with self.assertRaises(ValueError):
    cond_op.get_attr("_lower_using_switch_merge")

  nested_if_ops = []
  for func in ops.get_default_graph()._functions.values():
    nested_if_ops.extend(
        op for op in func._graph.get_operations() if op.type == "If")
  self.assertEqual(len(nested_if_ops), 1)
  with self.assertRaises(ValueError):
    nested_if_ops[0].get_attr("_lower_using_switch_merge")

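# The `_createCond`/`_createNestedCond` helpers used by the lowering tests
# are defined outside this excerpt. A minimal sketch, assuming `cond_v2`
# emits a functional `StatelessIf` op reachable through the output tensor's
# producing op (the pattern `output.op.inputs[0].op` also appears in
# testNoOptionalsInXla below); the exact bodies here are assumptions:


def _createCond(self, name):
  """Builds a cond_v2 call; returns (output tensor, StatelessIf op)."""
  pred = constant_op.constant(True, name="pred")
  x = constant_op.constant(1.0, name="x")
  output = cond_v2.cond_v2(pred, lambda: x, lambda: x + 1., name=name)
  cond_op = output.op.inputs[0].op
  self.assertEqual(cond_op.type, "StatelessIf")
  return output, cond_op


def _createNestedCond(self, name):
  """Like _createCond, but the true branch contains a nested cond_v2."""
  pred = constant_op.constant(True, name="pred")
  x = constant_op.constant(1.0, name="x")

  def true_fn():
    return cond_v2.cond_v2(pred, lambda: x, lambda: x + 1.)

  output = cond_v2.cond_v2(pred, true_fn, lambda: x + 2., name=name)
  cond_op = output.op.inputs[0].op
  self.assertEqual(cond_op.type, "StatelessIf")
  return output, cond_op
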
def testCondAndTensorArrayInDefun(self):
  with self.cached_session(), self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    @function.defun
    def f():
      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
      output = control_flow_ops.cond(
          constant_op.constant(True),
          lambda: ta.write(0, 5.),
          lambda: ta.write(0, 10.))
      return output.stack()

    output_t = f()
    self.assertAllEqual(self.evaluate(output_t), [5.])
    xla_context.Exit()

def testLoweringDisabledInXLA(self):
  with self.session(graph=ops.Graph()) as sess:
    # Build the cond_v2 in an XLA context
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    cond_output, cond_op = self._createCond("cond")
    xla_context.Exit()

    # Check lowering attr is not set.
    with self.assertRaises(ValueError):
      cond_op.get_attr("_lower_using_switch_merge")

    # Check the actual graph that is run.
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    run_metadata = config_pb2.RunMetadata()
    sess.run(cond_output, options=run_options, run_metadata=run_metadata)

    # Lowering disabled in XLA, there should be no `Switch` node
    self.assertFalse(
        _has_node_with_op(run_metadata, "Switch"),
        "A `Switch` op exists, but the graph should not be lowered.")

    if test_util.is_xla_enabled():
      # If XLA is actually enabled then we expect the StatelessIf to have
      # been put inside an XLA cluster.
      self.assertFalse(
          _has_node_with_op(run_metadata, "StatelessIf"),
          ("A `StatelessIf` op was found, but the node should have been " +
           "clustered."))
      self.assertTrue(
          _has_node_with_op(run_metadata, "_XlaCompile"),
          ("An `_XlaCompile` op was not found, but the `StatelessIf` (at " +
           "least) op should have been clustered."))
      self.assertTrue(
          _has_node_with_op(run_metadata, "_XlaRun"),
          ("An `_XlaRun` op was not found, but the `StatelessIf` (at " +
           "least) op should have been clustered."))
    else:
      # Lowering disabled in XLA, there should still be an `If` node
      self.assertTrue(
          _has_node_with_op(run_metadata, "StatelessIf"),
          ("A `StatelessIf` op was not found, but the graph should not be " +
           "lowered."))

def testCondAndTensorArrayInDefun(self):
  # TODO(b/132430685): Make test more useful. Also b/129396295, b/127846988
  with self.session(), self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    @function.defun
    def f():
      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
      output = control_flow_ops.cond(
          constant_op.constant(True),
          lambda: ta.write(0, 5.),
          lambda: ta.write(0, 10.))
      return output.stack()

    output_t = f()
    self.assertAllEqual([5.], self.evaluate(output_t))
    xla_context.Exit()

def testCondAndTensorArrayInDefun_constFolding(self):
  g = ops.Graph()
  with session.Session(graph=g), g.as_default(), self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    @function.defun
    def f():
      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
      output = control_flow_ops.cond(
          constant_op.constant(False),
          lambda: ta.write(0, 5.),
          lambda: ta.write(0, 10.))
      return output.stack()

    output_t = f()
    self.assertAllEqual([10.], self.evaluate(output_t))
    xla_context.Exit()

def testCondAndTensorArray_xlaCompile(self):
  self.skipTest("b/127846988")
  # Fails with "Uninitialized arguments" in XlaIfOp::Compile
  with self.session(), self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    def f():
      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
      output = control_flow_ops.cond(
          constant_op.constant(True),
          lambda: ta.write(0, 5.),
          lambda: ta.write(0, 10.))
      return output.stack()

    output_t, = xla.compile(f)
    self.assertAllEqual([5.], self.evaluate(output_t))
    xla_context.Exit()

def testSwitchCaseAndTensorArray_xlaCompile(self):
  self.skipTest("b/127846988")
  with self.session(), self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()

    def f():
      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
      output = control_flow_ops.switch_case(
          constant_op.constant(1), {
              0: lambda: ta.write(0, 5.),
              1: lambda: ta.write(0, 10.),
              2: lambda: ta.write(0, 15.),
          })
      return output.stack()

    output_t, = xla.compile(f)
    self.assertAllEqual([10.], self.evaluate(output_t))
    xla_context.Exit()

def __call__(self, *args, **kwds):
  """Calls the graph function and warns about too frequent tracing."""
  context.ensure_initialized()
  if RUN_FUNCTIONS_EAGERLY:
    return self._python_function(*args, **kwds)

  tracing_count = self._get_tracing_count()
  if self._experimental_compile:
    # V2 control flow relies on XLAControlFlowContext to generate an
    # XLA-compatible function graph.
    xla_context = control_flow_ops.XLAControlFlowContext()
    try:
      xla_context.Enter()
      result = self._call(*args, **kwds)
    finally:
      xla_context.Exit()
  else:
    result = self._call(*args, **kwds)

  if tracing_count == self._get_tracing_count():
    self._call_counter.called_without_tracing()
    return result

  self._call_counter.called_with_tracing()
  recent_tracing_count = self._call_counter.get_tracing_count()
  if recent_tracing_count >= FREQUENT_TRACING_WARNING_THRESHOLD:
    logging.warning(
        "{} out of the last {} calls to {} triggered tf.function retracing. "
        "Tracing is expensive and the excessive number of tracings is likely "
        "due to passing python objects instead of tensors. Also, tf.function "
        "has experimental_relax_shapes=True option that relaxes argument "
        "shapes that can avoid unnecessary retracing. Please refer to "
        "https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args"
        " and https://www.tensorflow.org/api_docs/python/tf/function for more "
        "details.".format(recent_tracing_count, self._call_counter.call_count,
                          self._python_function))

  return result

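# A minimal usage sketch for the `_experimental_compile` branch above,
# assuming a TF version where `tf.function(experimental_compile=True)` is
# accepted: the call is traced inside an XLAControlFlowContext, so v2 control
# flow (cond_v2/while_v2) keeps its functional If/While form instead of being
# lowered to Switch/Merge. The function name below is hypothetical.


@def_function.function(experimental_compile=True)
def _compiled_cond_example(x):
  # cond_v2 traced under the XLA context keeps the functional StatelessIf op.
  return control_flow_ops.cond(x > 0., lambda: x * 2., lambda: -x)
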
def testNoOptionalsInXla(self):

  @def_function.function
  def func_with_cond():
    pred = constant_op.constant(True, name="pred")
    x = constant_op.constant(1.0, name="x")

    def true_fn():
      intermediate = x + 1
      return intermediate * x

    def false_fn():
      return x + 1

    output = cond_v2.cond_v2(pred, true_fn, false_fn)
    grad = gradients_impl.gradients(output, x)[0]

    forward_if_op = output.op.inputs[0].op
    gradient_if_op = grad.op.inputs[0].op

    def verify_no_optional_ops(op, branch_name):
      branch_function = ops.get_default_graph()._get_function(
          op.get_attr(branch_name).name)
      function_def = branch_function.definition
      for node_def in function_def.node_def:
        self.assertNotIn(node_def.op, _OPTIONAL_OPS)

    verify_no_optional_ops(forward_if_op, "then_branch")
    verify_no_optional_ops(forward_if_op, "else_branch")
    verify_no_optional_ops(gradient_if_op, "then_branch")
    verify_no_optional_ops(gradient_if_op, "else_branch")

    return grad

  xla_context = control_flow_ops.XLAControlFlowContext()
  xla_context.Enter()
  func_with_cond()
  xla_context.Exit()

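# `_OPTIONAL_OPS` above comes from outside this excerpt. A plausible
# definition, assuming it lists the Optional-variant ops that cond lowering
# may emit and that must not appear inside XLA branch functions:

_OPTIONAL_OPS = frozenset([
    "OptionalFromValue",
    "OptionalGetValue",
    "OptionalHasValue",
    "OptionalNone",
])
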
def _GetNodeNames(use_xla):
  with self.session():
    input_tensor = array_ops.placeholder(np.float32, shape=input_sizes)

    if use_xla:
      with self.test_scope():
        # pylint: disable=protected-access
        graph = ops.get_default_graph()
        graph._set_control_flow_context(
            control_flow_ops.XLAControlFlowContext())
        # pylint: enable=protected-access
        conv2d_op = layer(
            filters=64,
            kernel_size=filter_sizes,
            dilation_rate=dilations,
            padding="same")
        _ = conv2d_op(input_tensor)
        return [n.name for n in ops.get_default_graph().as_graph_def().node]
    else:
      with ops.device("CPU"):
        conv2d_op = layer(
            filters=64,
            kernel_size=filter_sizes,
            dilation_rate=dilations,
            padding="same")
        _ = conv2d_op(input_tensor)
        names = [n.name for n in ops.get_default_graph().as_graph_def().node]
        # Filter out space-to-depth ops.
        return [
            name for name in names
            if "space" not in name and "Space" not in name
        ]

def _testMaxItersSimple(self):
  if is_compile_on_demand():
    self.skipTest("list_ops are not supported in cpu_ondemand")
  with self.session() as sess, self.test_scope():
    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    v = constant_op.constant(1.0)
    p = array_ops.placeholder(dtype=dtypes.int32)

    def create_while_loop():
      iterations = array_ops.size(p, name="iterations")
      r = control_flow_ops.while_loop(
          lambda *_: True,
          lambda i, x: (i + 1, v * x), (0, 1.0),
          maximum_iterations=iterations,
          name="outer")
      return array_ops.identity(r[1])

    output = create_while_loop()
    output = gradients_impl.gradients(output, v)[0]

    result = sess.run(output, feed_dict={p: [0, 0, 0]})
    print(result)
    xla_context.Exit()