Example #1
    def simpleTest(self, arg0, arg1, global_jit_level):
        config = config_pb2.ConfigProto()
        config.graph_options.optimizer_options.global_jit_level = global_jit_level

        with session_lib.Session(config=config) as sess:
            a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1")
            a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2")
            # Two element-wise ops. We need at least two ops since single
            # element clusters are not passed to XLA in fusion_only mode.
            a3 = a1 * a2
            a4 = a3 + a1
            # A matmul to break XLA clustering.
            a5 = math_ops.matmul(a4, a1)
            # Two more element-wise ops.
            a6 = a5 - a4
            a7 = a6 + a2

            run_metadata = config_pb2.RunMetadata()
            output = test_utils.RunWithWarmup(
                sess,
                a7, {
                    a1: arg0,
                    a2: arg1
                },
                run_metadata=run_metadata,
                options=config_pb2.RunOptions(
                    trace_level=config_pb2.RunOptions.FULL_TRACE))

            labels = RunMetadataLabels(run_metadata)

            xla_compile_count = sum("XlaCompile(" in x for x in labels)
            xla_run_count = sum("XlaRun(" in x for x in labels)
            self.assertEqual(xla_compile_count, xla_run_count)

            return output, xla_run_count
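
Example #1 calls a RunMetadataLabels helper that is not defined in the snippet. A minimal sketch, assuming it simply flattens the per-node timeline labels out of the traced RunMetadata so they can be scanned for "XlaCompile("/"XlaRun(" markers:

# Sketch of the undefined RunMetadataLabels helper (assumption: it collects
# the timeline labels recorded during a FULL_TRACE run).
def RunMetadataLabels(run_metadata):
    labels = []
    for dev_stats in run_metadata.step_stats.dev_stats:
        for node_stats in dev_stats.node_stats:
            labels.append(node_stats.timeline_label)
    return labels
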
Example #2
    def testCond(self):
        """Tests that compilation handles switch operators."""

        with self.session(config=NoRewriteSessionConfig()) as session:
            x = array_ops.placeholder(dtypes.float32)
            y = array_ops.placeholder(dtypes.float32)
            c = array_ops.placeholder(dtypes.bool)
            with jit_scope():
                z = x + 1.0
                w = control_flow_ops.cond(c, lambda: z, lambda: y)
                t = math_ops.add(z, w)

            # If JIT compilation chooses to cluster z and t, then execution will
            # deadlock.

            run_metadata = config_pb2.RunMetadata()
            result = test_utils.RunWithWarmup(
                session,
                t, {
                    x: np.float32(2),
                    y: np.float32(4),
                    c: True
                },
                run_metadata=run_metadata,
                options=config_pb2.RunOptions(
                    trace_level=config_pb2.RunOptions.FULL_TRACE))
            self.assertTrue(MetadataHasXlaRunOp(run_metadata))
            self.assertAllClose(result, np.float32(6), rtol=1e-1)
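
MetadataHasXlaRunOp is likewise used but never shown. A plausible sketch, assuming it scans the same timeline labels for an XlaRun kernel (the "_XlaRun" substring is an assumption about how that kernel is labeled):

# Sketch (assumption): true iff the traced timeline contains an XlaRun kernel,
# i.e. evidence that an XLA cluster actually executed.
def MetadataHasXlaRunOp(run_metadata):
    return any("_XlaRun" in label for label in RunMetadataLabels(run_metadata))
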
Example #3
        def _Run(compiled):
            @function.Defun(compiled=compiled)
            def Forward(x):
                return math_ops.log(x)

            g = ops.Graph()
            with g.as_default():
                x = array_ops.placeholder(dtypes.float32)
                y = Forward(x)
                dx, = gradients_impl.gradients(y, [x], 1.0)

            cfg = NoRewriteSessionConfig()
            cfg.graph_options.optimizer_options.opt_level = (
                config_pb2.OptimizerOptions.L1)
            cfg.graph_options.optimizer_options.do_function_inlining = True
            with session_lib.Session(graph=g, config=cfg) as sess:
                run_metadata = config_pb2.RunMetadata()
                dx_val = test_utils.RunWithWarmup(
                    sess,
                    dx,
                    feed_dict={x: 100.},
                    run_metadata=run_metadata,
                    options=config_pb2.RunOptions(
                        trace_level=config_pb2.RunOptions.FULL_TRACE))
            self.assertAllClose(dx_val, 0.01)
            return RunMetadataLabels(run_metadata)
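
Several examples build their session from NoRewriteSessionConfig(). A sketch of what such a helper plausibly returns, assuming its job is to disable the Grappler rewrites that could otherwise optimize the test graphs away before XLA sees them (the exact field choices are assumptions, not the project's verbatim code):

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

# Sketch (assumption): a ConfigProto with Grappler passes switched off so the
# graph reaching XLA is exactly the one the test built.
def NoRewriteSessionConfig():
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        function_optimization=rewriter_config_pb2.RewriterConfig.OFF)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    return config_pb2.ConfigProto(graph_options=graph_options)

Returning a plain ConfigProto is consistent with Example #3, which mutates the result further before opening its session.
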
Example #4
    def testIgnoredArguments(self):
        """Tests that JIT computations can ignore formal parameters."""

        with self.session(config=NoRewriteSessionConfig()) as sess:
            x = array_ops.placeholder(dtypes.int32)
            y = array_ops.placeholder(dtypes.int32)
            with jit_scope():
                z = math_ops.add(x, x)
                w = math_ops.add(y, y)
                # Pulls 'w' into the same compilation via control dependencies.
                with ops.control_dependencies([w]):
                    n = control_flow_ops.no_op()
                with ops.control_dependencies([n]):
                    t = math_ops.add(z, z)

            run_metadata = config_pb2.RunMetadata()
            out = test_utils.RunWithWarmup(
                sess,
                t, {
                    x: np.int32(7),
                    y: np.int32(404)
                },
                run_metadata=run_metadata,
                options=config_pb2.RunOptions(
                    trace_level=config_pb2.RunOptions.FULL_TRACE))
            self.assertTrue(MetadataHasXlaRunOp(run_metadata))
            self.assertAllClose(28, out)
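
test_utils.RunWithWarmup appears in every example but is not defined here. Its name, plus the fact that the traced run must record XlaRun kernels rather than compilation noise, suggests a sketch like the following (the warm-up count is an assumption):

# Sketch (assumption): run the op a couple of times so XLA compilation has
# already happened, then do the real traced run.
def RunWithWarmup(sess, op_to_run, feed_dict, options=None, run_metadata=None):
    for _ in range(2):
        sess.run(op_to_run, feed_dict, options=options)
    return sess.run(op_to_run, feed_dict, options=options,
                    run_metadata=run_metadata)
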
Example #5
    def testReshape(self):
        """Tests an operator with compile-time constant and non-constant inputs."""

        with self.session(config=NoRewriteSessionConfig()) as sess:
            x = array_ops.placeholder(dtypes.float32)
            y = array_ops.placeholder(dtypes.int32)
            with jit_scope():
                # Reshape's first argument is non-constant in the JIT, but its second
                # (shape) argument will be treated as a compile-time constant for
                # each JIT compilation.
                # We do not use a tf.constant() argument since we want to ensure the
                # shape is still a run-time argument to the JIT, and not
                # statically known as part of the JIT compilation's input graph.
                z = array_ops.reshape(x, y)
            run_metadata = config_pb2.RunMetadata()
            out = test_utils.RunWithWarmup(
                sess,
                z, {
                    x: np.array([1, 2, 3, 4, 5, 6], np.float32),
                    y: [-1, 3]
                },
                run_metadata=run_metadata,
                options=config_pb2.RunOptions(
                    trace_level=config_pb2.RunOptions.FULL_TRACE))
            self.assertTrue(MetadataHasXlaRunOp(run_metadata))
            self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32),
                                out)
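
For reference, the [-1, 3] fed for y asks reshape to infer the leading dimension, which is why the expected output is 2x3. The same inference in plain NumPy:

import numpy as np

# -1 means "infer this dimension": 6 elements / 3 columns -> 2 rows.
np.array([1, 2, 3, 4, 5, 6], np.float32).reshape(-1, 3)
# => [[1., 2., 3.],
#     [4., 5., 6.]]
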
Example #6
    def testDenseLayerJitScopeDefinedShape(self):
        """Tests that the dense layer node is properly compiled in jit scope.

        A dense layer with a statically shaped input tensor should be compiled
        into a single XlaCompile/XlaRun op pair by XLA.
        """

        with self.cached_session() as sess:
            x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32)
            with jit_scope():
                y = layers.dense(x, 3)

            self.evaluate(variables.global_variables_initializer())
            run_metadata = config_pb2.RunMetadata()
            test_utils.RunWithWarmup(
                sess,
                y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])},
                run_metadata=run_metadata,
                options=config_pb2.RunOptions(
                    trace_level=config_pb2.RunOptions.FULL_TRACE))

        labels = GetRunMetadataLabels(run_metadata)
        self.assertEqual(1, self.countXlaOps(labels))
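
The dense-layer examples (#6 through #8) lean on three more helpers the snippets do not define: GetRunMetadataLabels, countXlaOps, and InLabels. Sketches consistent with how they are used (assumptions, not the project's verbatim code):

# Sketch (assumption): same label extraction as RunMetadataLabels above.
def GetRunMetadataLabels(run_metadata):
    labels = []
    for dev_stats in run_metadata.step_stats.dev_stats:
        for node_stats in dev_stats.node_stats:
            labels.append(node_stats.timeline_label)
    return labels


# Sketch (assumption): true iff some timeline label contains substr.
def InLabels(labels, substr):
    return any(substr in label for label in labels)


# Sketch (assumption): a method on the test class that counts
# XlaCompile/XlaRun pairs the way Example #1 does inline, checking that the
# two counts agree and returning the pair count.
def countXlaOps(self, labels):
    xla_compile_count = sum("XlaCompile(" in x for x in labels)
    xla_run_count = sum("XlaRun(" in x for x in labels)
    self.assertEqual(xla_compile_count, xla_run_count)
    return xla_run_count
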
Example #7
    def testDenseLayerAutoJit(self):
        """Tests dense layer compilation in auto-jit mode.

        The dense layer should be compiled into a single XlaCompile/XlaRun op
        pair in auto-jit mode.
        """

        os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit " +
                                      os.environ.get("TF_XLA_FLAGS", ""))
        config = config_pb2.ConfigProto()
        config.graph_options.optimizer_options.global_jit_level = (
            config_pb2.OptimizerOptions.ON_1)

        with self.session(config=config) as sess:
            x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
            y = layers.dense(x, 3)

            self.evaluate(variables.global_variables_initializer())
            run_metadata = config_pb2.RunMetadata()
            test_utils.RunWithWarmup(
                sess,
                y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])},
                run_metadata=run_metadata,
                options=config_pb2.RunOptions(
                    trace_level=config_pb2.RunOptions.FULL_TRACE))

        labels = GetRunMetadataLabels(run_metadata)
        self.assertEqual(1, self.countXlaOps(labels))
        self.assertFalse(InLabels(labels, "MatMul"))
Example #8
    def testDenseLayerJitScopeUndefinedShape(self):
        """Tests that the dense layer node is properly compiled in jit scope.

        A dense layer uses a Shape op to read its input's shape when that shape
        is not fully defined. XLA does not cluster Shape with other operators,
        but inside experimental_jit_scope XLA is forced to compile the Shape op
        into its own cluster, so the dense layer is split into TWO
        XlaCompile/XlaRun op pairs.
        """

        with self.cached_session() as sess:
            x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
            with jit_scope():
                y = layers.dense(x, 3)

            self.evaluate(variables.global_variables_initializer())
            run_metadata = config_pb2.RunMetadata()
            test_utils.RunWithWarmup(
                sess,
                y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])},
                run_metadata=run_metadata,
                options=config_pb2.RunOptions(
                    trace_level=config_pb2.RunOptions.FULL_TRACE))

        labels = GetRunMetadataLabels(run_metadata)
        self.assertEqual(2, self.countXlaOps(labels))
        self.assertFalse(InLabels(labels, "MatMul"))
Example #9
    def testNoOutputs(self):
        with session_lib.Session() as sess:

            # Check that calling the result as a compiled kernel doesn't crash.
            @function.Defun(compiled=True)
            def KernelWithNoOutputs():
                a = constant_op.constant(100)  # pylint: disable=unused-variable

            call = KernelWithNoOutputs()  # pylint: disable=assignment-from-no-return
            test_utils.RunWithWarmup(sess, call, {})
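
function.Defun(compiled=True) is the legacy TF1 hook for forcing XLA compilation of a function. On modern TensorFlow the closest equivalent would be tf.function with jit_compile=True (assumption: TF 2.5 or newer; earlier 2.x releases called the flag experimental_compile):

import tensorflow as tf

# Rough TF2 analogue of the no-output Defun test above: the compiled function
# returns nothing, and calling it must still not crash.
@tf.function(jit_compile=True)
def kernel_with_no_outputs():
    tf.constant(100)  # value is deliberately discarded

kernel_with_no_outputs()
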
Example #10
    def _compare(self,
                 fn,
                 args,
                 require_kernel_launch=True,
                 name=None,
                 noinline=None):
        with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
            placeholders = []
            feeds = {}
            for arg in args:
                placeholder = array_ops.placeholder(dtypes.as_dtype(arg.dtype),
                                                    list(arg.shape))
                placeholders.append(placeholder)
                feeds[placeholder] = arg

            compiled_op = CompiledKernel(fn,
                                         *placeholders,
                                         name=name,
                                         noinline=noinline)
            direct_op = fn(*placeholders)

            run_metadata = config_pb2.RunMetadata()
            compiled = test_utils.RunWithWarmup(
                sess, compiled_op, feeds,
                config_pb2.RunOptions(
                    trace_level=config_pb2.RunOptions.FULL_TRACE),
                run_metadata)
            print("Compiled Result {}".format(compiled))

            if require_kernel_launch:
                self.assertTrue(MetadataHasXlaRunOp(run_metadata))

                direct = sess.run(direct_op, feeds)
                print("Direct Result {}".format(direct))

                if (isinstance(compiled, (tuple, list))
                        and (isinstance(direct, (tuple, list)))):
                    for (x, y) in zip(compiled, direct):
                        self.assertAllClose(x, y, rtol=1e-1)
                else:
                    self.assertAllClose(compiled, direct, rtol=1e-2)
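
CompiledKernel in Example #10 is another undefined helper. Given how it is called, a plausible sketch is a thin wrapper that re-declares fn as a compiled Defun, mirroring Example #9 (signature and kwargs are assumptions):

from tensorflow.python.framework import function

# Sketch (assumption): wrap fn in a Defun marked compiled=True and call it on
# the placeholders, so its XlaRun kernel shows up in the trace.
def CompiledKernel(fn, *inputs, **kwargs):
    name = kwargs.pop("name", None)
    noinline = kwargs.pop("noinline", None)

    @function.Defun(func_name=name, noinline=noinline, compiled=True)
    def Compiled(*args):
        return fn(*args)

    return Compiled(*inputs)
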
Example #11
    def testExplicitMarking(self):
        """Test explicit marking of operators to compile."""
        batch_size = 16
        image_size = 28 * 28
        num_classes = 10

        with ops.Graph().as_default():
            x = array_ops.placeholder(dtypes.float32)
            w = array_ops.placeholder(dtypes.float32)
            b = array_ops.placeholder(dtypes.float32)
            with jit_scope():
                y1 = math_ops.matmul(x, w)
            y2 = math_ops.add(y1, b)
            with jit_scope():
                y = math_ops.square(y2)

            dw = np.random.random_sample(
                (image_size, num_classes)).astype(np.float32)
            db = np.random.random_sample((num_classes)).astype(np.float32)
            dx = np.random.random_sample(
                (batch_size, image_size)).astype(np.float32)
            with session_lib.Session() as sess:
                run_metadata = config_pb2.RunMetadata()
                output = test_utils.RunWithWarmup(
                    sess,
                    y, {
                        x: dx,
                        w: dw,
                        b: db
                    },
                    run_metadata=run_metadata,
                    options=config_pb2.RunOptions(
                        trace_level=config_pb2.RunOptions.FULL_TRACE))

                # TODO(phawkins): really we would like to test that there were exactly
                # two kernel launches. However, we have no reliable way to determine
                # that.
                self.assertTrue(MetadataHasXlaRunOp(run_metadata))

                expected = np.square(np.dot(dx, dw) + db)
                self.assertAllClose(expected, output, rtol=1e-1)
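
Finally, the jit_scope used throughout is never imported in these snippets. In the TF1-era codebase the examples appear to come from, it would typically be bound as follows (the import path is an assumption; the symbol lived under tensorflow.contrib.compiler.jit in older releases):

from tensorflow.python.compiler.xla import jit

jit_scope = jit.experimental_jit_scope
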
Example #12
  def testDenseLayerJitScopeUndefinedShape(self):
    """Tests that the dense layer node is properly compiled in jit scope."""

    with self.session() as sess:
      x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
      with jit_scope():
        y = layers.dense(x, 3)

      self.evaluate(variables.global_variables_initializer())
      run_metadata = config_pb2.RunMetadata()
      test_utils.RunWithWarmup(
          sess,
          y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])},
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))

    labels = GetRunMetadataLabels(run_metadata)
    self.assertEqual(1, self.countXlaOps(labels))
    self.assertFalse(InLabels(labels, "MatMul"))