Example 1
    def testLoweringDisabledInXLA(self):
        with self.test_session(graph=ops.Graph()) as sess:
            # Build the cond_v2 in an XLA context
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()
            out_cond = self._createCond("cond")
            xla_context.Exit()

            run_options = config_pb2.RunOptions(output_partition_graphs=True)
            run_metadata = config_pb2.RunMetadata()
            sess.run(out_cond, options=run_options, run_metadata=run_metadata)

            # Lowering is disabled in XLA, so there should be no `Switch` node
            switch_found = any(
                any(node.op == "Switch" for node in graph.node)
                for graph in run_metadata.partition_graphs)

            self.assertFalse(
                switch_found,
                "A `Switch` op exists, but the graph should not be lowered.")

            # Lowering is disabled in XLA, so there should still be an `If` node
            if_found = any(
                any(node.op == "If" for node in graph.node)
                for graph in run_metadata.partition_graphs)

            self.assertTrue(
                if_found,
                "An `If` op was not found, but the graph should not be lowered."
            )
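
Note: `_createCond` is a helper defined elsewhere in the test class. A minimal sketch, assuming a `cond_v2` with trivial branches (Example 2 shows a variant that also returns the underlying `If` op):

    def _createCond(self, name):
        # Hypothetical sketch: build a cond_v2 and return its output tensor.
        pred = constant_op.constant(True, name="pred")
        x = constant_op.constant(1.0, name="x")
        return cond_v2.cond_v2(pred, lambda: x, lambda: x + 1.0, name=name)
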
Example 2
  def testLoweringDisabledInXLA(self):
    with self.session(graph=ops.Graph()) as sess:
      # Build the cond_v2 in an XLA context
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      cond_output, cond_op = self._createCond("cond")
      xla_context.Exit()

      # Check lowering attr is not set.
      with self.assertRaises(ValueError):
        cond_op.get_attr("_lower_using_switch_merge")

      # Check the actual graph that is run.
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      sess.run(cond_output, options=run_options, run_metadata=run_metadata)

      # Lowering is disabled in XLA, so there should be no `Switch` node
      self.assertFalse(
          _has_node_with_op(run_metadata, "Switch"),
          "A `Switch` op exists, but the graph should not be lowered.")

      # Lowering is disabled in XLA, so there should still be a `StatelessIf` node
      self.assertTrue(
          _has_node_with_op(run_metadata, "StatelessIf"),
          "A `StatelessIf` op was not found, but the graph should not be lowered.")
Example 3
    def testSwitchCaseConstPropagation(self):
        self.skipTest("b/127846988")
        with self.session() as sess, self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()

            x = array_ops.placeholder(dtypes.float32)
            p = array_ops.placeholder(dtypes.int32)

            def branch0():
                return 5.

            def branch1():
                return 15.

            # TODO(b/129021699): Wrapping this in a tf.function does not work.
            def branch2():
                # This emits a StridedSlice op which expects the index to be a
                # compile-time const.
                return x[p]

            output = control_flow_ops.switch_case(constant_op.constant(2), {
                0: branch0,
                1: branch1,
                2: branch2,
            })

            self.assertAllEqual(
                7., sess.run(output, feed_dict={
                    x: [0., 1., 7.],
                    p: 2,
                }))

            xla_context.Exit()
Example 4
    def testCondConstPropagation_xlaCompile(self):
        self.skipTest("b/132430685")
        with self.session(), self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()

            x = array_ops.placeholder_with_default([0., 1., 2.], shape=[3])
            p = constant_op.constant(1)

            def f():
                # TODO(b/129021699): Wrapping this in a tf.function does not work.
                def if_true():
                    # This emits a StridedSlice op which expects the index to be a
                    # compile-time const.
                    return x[p]

                def if_false():
                    return 5.

                return control_flow_ops.cond(constant_op.constant(True),
                                             if_true, if_false)

            output = xla.compile(f)

            self.assertAllEqual(1., self.evaluate(output))

            xla_context.Exit()
Example 5
    def testCondConstPropagation_errorMsg(self):
        self.skipTest("b/132430685")
        with self.session() as sess, self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()

            x = array_ops.placeholder(dtypes.float32)
            p = random_ops.random_uniform([],
                                          minval=1,
                                          maxval=3,
                                          dtype=dtypes.int32)

            # TODO(b/129021699): Wrapping this in a tf.function does not work.
            def if_true():
                # This emits a StridedSlice op which expects the index to be a
                # compile-time const.
                return x[:p]

            def if_false():
                return array_ops.fill([p], 5.)

            output = control_flow_ops.cond(constant_op.constant(True), if_true,
                                           if_false)

            with self.assertRaisesRegex(errors.InvalidArgumentError,
                                        "must be a compile-time constant"):
                sess.run(output, feed_dict={
                    x: [0., 1., 2.],
                })

            xla_context.Exit()
Example 6
    def testCondConstPropagation(self):
        with self.session() as sess, self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()

            x = array_ops.placeholder(dtypes.float32)
            p = array_ops.placeholder(dtypes.int32)

            # TODO(b/129021699): Wrapping this in a tf.function does not work.
            def if_true():
                # This emits a StridedSlice op which expects the index to be a
                # compile-time const.
                return x[p]

            def if_false():
                return 5.

            output = control_flow_ops.cond(constant_op.constant(True), if_true,
                                           if_false)

            self.assertAllEqual(
                1., sess.run(output, feed_dict={
                    x: [0., 1., 2.],
                    p: 1
                }))

            xla_context.Exit()
Example 7
  def testMap(self):
    if is_compile_on_demand():
      self.skipTest("list_ops are not supported in cpu_ondemand")
    with self.session(), self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      nums = [1, 2, 3, 4, 5, 6]
      elems = constant_op.constant(nums, name="data")
      r = map_fn.map_fn(lambda x: math_ops.multiply(math_ops.add(x, 3), 2),
                        elems)
      self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums]))
      xla_context.Exit()
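
Note: `is_compile_on_demand()` (also used in Examples 8 and 21) is defined elsewhere in the test file. A plausible sketch; the exact flag string is an assumption:

import os


def is_compile_on_demand():
  # Assumed implementation: detect XLA compile-on-demand mode from the
  # TF_XLA_FLAGS environment variable.
  return ("TF_XLA_FLAGS" in os.environ and
          "tf_xla_compile_on_demand" in os.environ["TF_XLA_FLAGS"])
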
Example 8
  def _testNestedWhileLoopWithMaxItersFromOuterContext(self):
    if is_compile_on_demand():
      self.skipTest("list_ops are not supported in cpu_ondemand")
    with self.session() as sess, self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      v = constant_op.constant(1.0)
      p = array_ops.placeholder(dtype=dtypes.int32)

      def mid_body_builder(iterations):

        def mid_body(i, x):
          r = control_flow_ops.while_loop(
              lambda *_: True,
              lambda i, x: (i + 1, v * x), (0, x),
              maximum_iterations=iterations,
              name="inner")
          return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

        return mid_body

      def outer_body(i, x):
        iterations = array_ops.size(p, name="iterations")
        return (i + 1, x + control_flow_ops.while_loop(
            lambda *_: True,
            mid_body_builder(iterations), (0, x),
            maximum_iterations=iterations,
            name="mid")[1])

      def create_while_loop():
        r = control_flow_ops.while_loop(
            lambda *_: True,
            outer_body, (0, 1.0),
            maximum_iterations=5,
            name="outer")
        return array_ops.identity(r[1])

      # p:placeholder
      # j = 0
      # i, x = 0, 1.
      # while j++ < 5:
      #   i1, x1 = 0, x
      #   while i1++ < len(p):
      #     i2, x2 = 0, x1
      #     while i2++ < len(p):
      #       x2 = v * x2
      #     x1 = grad(x1 + x2, v)
      #   x = x1
      # output = x
      output = create_while_loop()
      sess.run(output, feed_dict={p: [0, 0, 0]})
      xla_context.Exit()
Example 9
  def _call_for_each_replica(self, fn, args, kwargs):
    with distribute_lib.ReplicaContext(self._container_strategy(),
                                       replica_id_in_sync_group=0), \
        ops.device(self._ipu_device):
      # Make sure it is compiled as a single engine when called in graph mode.
      # This is similar to the mechanism used by xla.compile.
      xla_context = control_flow_ops.XLAControlFlowContext()
      try:
        xla_context.Enter()
        _validate_function_for_arguments(fn, args, kwargs)
        return fn(*args, **kwargs)
      finally:
        xla_context.Exit()
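
The try/finally here (and in Example 18) guarantees that `Exit()` runs even if the wrapped call raises. A hypothetical convenience wrapper, not part of the TensorFlow API, could package the same pattern as a context manager:

import contextlib

from tensorflow.python.ops import control_flow_ops


@contextlib.contextmanager
def xla_control_flow_scope():
  # Hypothetical helper: enter an XLAControlFlowContext and always exit it,
  # even if the body raises.
  ctx = control_flow_ops.XLAControlFlowContext()
  ctx.Enter()
  try:
    yield ctx
  finally:
    ctx.Exit()

Graph-building code could then use `with xla_control_flow_scope(): ...` instead of pairing Enter() and Exit() by hand.
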
Example 10
    def testCondNoInputs(self):
        """Verifies against `Failed precondition: Expected one input shape`."""

        with self.session(), self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()

            for pred in True, False:
                cond_out = control_flow_ops.cond(
                    array_ops.placeholder_with_default(pred, []),
                    lambda: constant_op.constant(2.),
                    lambda: constant_op.constant(1.))
                self.assertEqual(int(pred) + 1., self.evaluate(cond_out))

            xla_context.Exit()
Example 11
    def testNestedLoweringDisabledInXLA(self):
        # Build the cond_v2 in an XLA context
        xla_context = control_flow_ops.XLAControlFlowContext()
        xla_context.Enter()
        _, cond_op = self._createNestedCond("cond")
        xla_context.Exit()

        # Check lowering attr is not set for either If node.
        with self.assertRaises(ValueError):
            cond_op.get_attr("_lower_using_switch_merge")

        nested_if_ops = []
        for func in ops.get_default_graph()._functions.values():
            nested_if_ops.extend(op for op in func._graph.get_operations()
                                 if op.type == "If")
        self.assertEqual(len(nested_if_ops), 1)
        with self.assertRaises(ValueError):
            nested_if_ops[0].get_attr("_lower_using_switch_merge")
Example 12
    def testCondAndTensorArrayInDefun(self):
        with self.cached_session(), self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()

            @function.defun
            def f():
                ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
                output = control_flow_ops.cond(constant_op.constant(True),
                                               lambda: ta.write(0, 5.),
                                               lambda: ta.write(0, 10.))

                return output.stack()

            output_t = f()
            self.assertAllEqual(self.evaluate(output_t), [5.])

            xla_context.Exit()
Example 13
  def testLoweringDisabledInXLA(self):
    with self.session(graph=ops.Graph()) as sess:
      # Build the cond_v2 in an XLA context
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      cond_output, cond_op = self._createCond("cond")
      xla_context.Exit()

      # Check lowering attr is not set.
      with self.assertRaises(ValueError):
        cond_op.get_attr("_lower_using_switch_merge")

      # Check the actual graph that is run.
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      sess.run(cond_output, options=run_options, run_metadata=run_metadata)

      # Lowering is disabled in XLA, so there should be no `Switch` node
      self.assertFalse(
          _has_node_with_op(run_metadata, "Switch"),
          "A `Switch` op exists, but the graph should not be lowered.")

      if test_util.is_xla_enabled():
        # If XLA is actually enabled then we expect the StatelessIf to have been
        # put inside an XLA cluster.
        self.assertFalse(
            _has_node_with_op(run_metadata, "StatelessIf"),
            ("A `StatelessIf` op was found, but the node should have been " +
             "clustered."))
        self.assertTrue(
            _has_node_with_op(run_metadata, "_XlaCompile"),
            ("An `_XlaCompile` op was not found, but the `StatelessIf` (at " +
             "least) op should have been clustered."))
        self.assertTrue(
            _has_node_with_op(run_metadata, "_XlaRun"),
            ("An `_XlaRun` op was not found, but the `StatelessIf` (at " +
             "least) op should have been clustered."))
      else:
        # Lowering is disabled in XLA, so there should still be a `StatelessIf` node
        self.assertTrue(
            _has_node_with_op(run_metadata, "StatelessIf"),
            ("A `StatelessIf` op was not found, but the graph should not be " +
             "lowered."))
Example 14
    def testCondAndTensorArrayInDefun(self):
        # TODO(b/132430685): Make test more useful. Also b/129396295, b/127846988
        with self.session(), self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()

            @function.defun
            def f():
                ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
                output = control_flow_ops.cond(constant_op.constant(True),
                                               lambda: ta.write(0, 5.),
                                               lambda: ta.write(0, 10.))

                return output.stack()

            output_t = f()
            self.assertAllEqual([5.], self.evaluate(output_t))

            xla_context.Exit()
Example 15
    def testCondAndTensorArrayInDefun_constFolding(self):
        g = ops.Graph()
        with session.Session(graph=g), g.as_default(), self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()

            @function.defun
            def f():
                ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
                output = control_flow_ops.cond(constant_op.constant(False),
                                               lambda: ta.write(0, 5.),
                                               lambda: ta.write(0, 10.))

                return output.stack()

            output_t = f()
            self.assertAllEqual([10.], self.evaluate(output_t))

            xla_context.Exit()
Example 16
    def testCondAndTensorArray_xlaCompile(self):
        self.skipTest("b/127846988")
        # Fails with "Uninitialized arguments" in XlaIfOp::Compile
        with self.session(), self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()

            def f():
                ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
                output = control_flow_ops.cond(constant_op.constant(True),
                                               lambda: ta.write(0, 5.),
                                               lambda: ta.write(0, 10.))

                return output.stack()

            output_t, = xla.compile(f)
            self.assertAllEqual([5.], self.evaluate(output_t))

            xla_context.Exit()
Example 17
    def testSwitchCaseAndTensorArray_xlaCompile(self):
        self.skipTest("b/127846988")
        with self.session(), self.test_scope():
            xla_context = control_flow_ops.XLAControlFlowContext()
            xla_context.Enter()

            def f():
                ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1)
                output = control_flow_ops.switch_case(
                    constant_op.constant(1), {
                        0: lambda: ta.write(0, 5.),
                        1: lambda: ta.write(0, 10.),
                        2: lambda: ta.write(0, 15.),
                    })

                return output.stack()

            output_t, = xla.compile(f)
            self.assertAllEqual([10.], self.evaluate(output_t))

            xla_context.Exit()
Example 18
  def __call__(self, *args, **kwds):
    """Calls the graph function and warn too frequent tracings."""
    context.ensure_initialized()
    if RUN_FUNCTIONS_EAGERLY:
      return self._python_function(*args, **kwds)

    tracing_count = self._get_tracing_count()
    if self._experimental_compile:
      # V2 control flow relies on XLAControlFlowContext to generate a
      # XLA-compatible function graph.
      xla_context = control_flow_ops.XLAControlFlowContext()
      try:
        xla_context.Enter()
        result = self._call(*args, **kwds)
      finally:
        xla_context.Exit()
    else:
      result = self._call(*args, **kwds)

    if tracing_count == self._get_tracing_count():
      self._call_counter.called_without_tracing()
      return result

    self._call_counter.called_with_tracing()
    recent_tracing_count = self._call_counter.get_tracing_count()
    if recent_tracing_count >= FREQUENT_TRACING_WARNING_THRESHOLD:
      logging.warning(
          "{} out of the last {} calls to {} triggered tf.function retracing. "
          "Tracing is expensive and the excessive number of tracings is likely "
          "due to passing python objects instead of tensors. Also, tf.function "
          "has experimental_relax_shapes=True option that relaxes argument "
          "shapes that can avoid unnecessary retracing. Please refer to "
          "https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args"
          " and https://www.tensorflow.org/api_docs/python/tf/function for more "
          "details.".format(recent_tracing_count, self._call_counter.call_count,
                            self._python_function))

    return result
Example 19
  def testNoOptionalsInXla(self):

    @def_function.function
    def func_with_cond():
      pred = constant_op.constant(True, name="pred")
      x = constant_op.constant(1.0, name="x")

      def true_fn():
        intermediate = x + 1
        return intermediate * x

      def false_fn():
        return x + 1

      output = cond_v2.cond_v2(pred, true_fn, false_fn)
      grad = gradients_impl.gradients(output, x)[0]

      forward_if_op = output.op.inputs[0].op
      gradient_if_op = grad.op.inputs[0].op

      def verify_no_optional_ops(op, branch_name):
        branch_function = ops.get_default_graph()._get_function(
            op.get_attr(branch_name).name)
        function_def = branch_function.definition
        for node_def in function_def.node_def:
          self.assertNotIn(node_def.op, _OPTIONAL_OPS)

      verify_no_optional_ops(forward_if_op, "then_branch")
      verify_no_optional_ops(forward_if_op, "else_branch")
      verify_no_optional_ops(gradient_if_op, "then_branch")
      verify_no_optional_ops(gradient_if_op, "else_branch")

      return grad

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    func_with_cond()
    xla_context.Exit()
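
Note: `_OPTIONAL_OPS` is a module-level constant that is not shown. A sketch of a compatible definition, assuming it names TensorFlow's Optional* ops:

# Assumed definition; the names match TensorFlow's Optional* op set.
_OPTIONAL_OPS = frozenset([
    "OptionalFromValue",
    "OptionalGetValue",
    "OptionalHasValue",
    "OptionalNone",
])
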
Example 20
        def _GetNodeNames(use_xla):
            with self.session():
                input_tensor = array_ops.placeholder(np.float32,
                                                     shape=input_sizes)

                if use_xla:
                    with self.test_scope():
                        # pylint: disable=protected-access
                        graph = ops.get_default_graph()
                        graph._set_control_flow_context(
                            control_flow_ops.XLAControlFlowContext())
                        # pylint: enable=protected-access
                        conv2d_op = layer(filters=64,
                                          kernel_size=filter_sizes,
                                          dilation_rate=dilations,
                                          padding="same")
                        _ = conv2d_op(input_tensor)
                        return [
                            n.name for n in
                            ops.get_default_graph().as_graph_def().node
                        ]
                else:
                    with ops.device("CPU"):
                        conv2d_op = layer(filters=64,
                                          kernel_size=filter_sizes,
                                          dilation_rate=dilations,
                                          padding="same")
                        _ = conv2d_op(input_tensor)
                        names = [
                            n.name for n in
                            ops.get_default_graph().as_graph_def().node
                        ]
                        # Filter out space-to-depth ops.
                        return [
                            name for name in names
                            if "space" not in name and "Space" not in name
                        ]
Example 21
  def _testMaxItersSimple(self):
    if is_compile_on_demand():
      self.skipTest("list_ops are not supported in cpu_ondemand")
    with self.session() as sess, self.test_scope():
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      v = constant_op.constant(1.0)
      p = array_ops.placeholder(dtype=dtypes.int32)

      def create_while_loop():
        iterations = array_ops.size(p, name="iterations")
        r = control_flow_ops.while_loop(
            lambda *_: True,
            lambda i, x: (i + 1, v * x), (0, 1.0),
            maximum_iterations=iterations,
            name="outer")
        return array_ops.identity(r[1])

      output = create_while_loop()
      output = gradients_impl.gradients(output, v)[0]

      result = sess.run(output, feed_dict={p: [0, 0, 0]})
      print(result)
      xla_context.Exit()