示例#1
0
文件: api_test.py 项目: yyht/jax
    def test_xla_computation(self):
        """Check the examples from the xla_computation docstring.

        Each case stages a function out to XLA and inspects the resulting HLO
        text for the expected ops and, for collectives, the replica groups.
        """

        def h(x):
            return np.sin(np.cos(x))

        hlo = api.xla_computation(h)(2.).GetHloText()
        for op_name in ('cosine', 'sine'):
            self.assertIn(op_name, hlo)

        def f(x):
            return x - lax.psum(x, 'i')

        # One named axis of size 4: a single all-reduce over all four replicas.
        hlo = api.xla_computation(f, axis_env=[('i', 4)])(2).GetHloText()
        self.assertIn('all-reduce', hlo)
        self.assertIn('replica_groups={{0,1,2,3}}', hlo)

        def g(x):
            return (lax.psum(x, 'i'),         # reduce along 'i'
                    lax.psum(x, 'j'),         # reduce along 'j'
                    lax.psum(x, ('i', 'j')))  # reduce along both axes

        # Two named axes (4 x 2 = 8 replicas). The expected groups below are
        # consistent with replica index i * 2 + j.
        two_axes = [('i', 4), ('j', 2)]
        hlo = api.xla_computation(g, axis_env=two_axes)(5.).GetHloText()
        expected_fragments = (
            'all-reduce',
            'replica_groups={{0,2,4,6},{1,3,5,7}}',
            'replica_groups={{0,1},{2,3},{4,5},{6,7}}',
            'replica_groups={{0,1,2,3,4,5,6,7}}',
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, hlo)
示例#2
0
    def test_xla_computation_args(self):
        """Check how positional arguments map onto XLA parameters."""
        def foo(x, y, z):
            return x + y + z

        # By default every positional argument becomes its own parameter.
        untupled = api.xla_computation(foo)(1., 2., 3.)
        self.assertEqual(
            len(untupled.GetProgramShape().parameter_shapes()), 3)

        # With tuple_args=True the arguments arrive as one tuple parameter.
        tupled = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
        shapes = tupled.GetProgramShape().parameter_shapes()
        self.assertEqual(len(shapes), 1)
        self.assertEqual(shapes[0].xla_element_type(),
                         xb.xla_client.PrimitiveType.TUPLE)
示例#3
0
    def test_scan_cond(self, with_jit=False):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            def body(c, x):
                x3 = hcb.id_print(x, where="s_1", output_stream=testing_stream)
                x4 = lax.cond(
                    x % 2 == 0, x3 + 1, lambda x: hcb.id_print(
                        x, where="s_t", output_stream=testing_stream), x3 + 1,
                    lambda x: hcb.id_print(-1,
                                           where="s_f",
                                           result=x,
                                           output_stream=testing_stream))
                return (c,
                        hcb.id_print(x4,
                                     where="s_2",
                                     output_stream=testing_stream))

            _, x10 = lax.scan(body, x2, jnp.arange(3))
            res = hcb.id_print(x10, where="10", output_stream=testing_stream)
            return res

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).as_hlo_text())

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            if with_jit:
                func = api.jit(func)
            res = func(1)
            self.assertAllClose(jnp.array([1, 2, 3]), res, check_dtypes=True)
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: s_1
0
where: s_t
1
where: s_2
1
where: s_1
1
where: s_f
-1
where: s_2
2
where: s_1
2
where: s_t
3
where: s_2
3
where: 10
[1 2 3]""", testing_stream.output)
        testing_stream.reset()
示例#4
0
    def test_jit_interleaving(self):
        """Several jit invocations with no data dependencies between them.

        Their taps may interleave. Each of the 10 jitted calls performs 10
        taps, so the tap callback must have run exactly 100 times once the
        outfeed receiver has exited.
        """
        # Several jit's without data dependencies; they may interfere
        count = 0  # Count tap invocations
        nr_arrays = 5

        def tap_func(arg, **kwargs):
            nonlocal count
            assert len(arg) == nr_arrays
            count += 1

        # This is the function that we'll run multiple times. The parameter is
        # named `nr_taps` so it does not shadow the `count` counter above.
        def func(x, nr_taps):
            for step in range(nr_taps):
                # Tap a list [x, x + 1, ..., x + nr_arrays - 1] and continue
                # with its last element, so x grows by nr_arrays - 1 per step.
                x = hcb.id_tap(tap_func, [x + j for j in range(nr_arrays)],
                               i=step)[-1]
            return x

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            x = jnp.array(1, dtype=np.int32)
            res = 0
            for _ in range(10):
                # No dependencies between the jit invocations
                res += api.jit(lambda x: func(x, 10))(x)
        logging.warning(
            "%s: %s", self._testMethodName,
            api.xla_computation(lambda x: func(x, 5))(1).as_hlo_text())
        self.assertEqual(100, count)
        # Each call returns 1 + 10 * (nr_arrays - 1) == 41; ten calls sum to 410.
        self.assertAllClose(10 * (1 + 10 * (nr_arrays - 1)), res,
                            check_dtypes=False)
示例#5
0
    def test_cond(self, with_jit=False):
        """A conditional"""
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            x4 = lax.cond(
                x % 2 == 0, x2 + 1, lambda x: hcb.id_print(
                    x, where="cond_t", output_stream=testing_stream), x2 + 1,
                lambda x: hcb.id_print(
                    -1, where="cond_f", result=x, output_stream=testing_stream)
            )
            x5 = hcb.id_print(x4 + 1,
                              where="end",
                              output_stream=testing_stream)
            return x5

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).as_hlo_text())
        transform = api.jit if with_jit else lambda f: f
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(4, transform(func)(1))
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: cond_f
-1
where: end
4""", testing_stream.output)
        testing_stream.reset()
示例#6
0
    def test_jit_while_pred_printing(self):
        """While with printing in the conditional."""
        raise SkipTest("Not yet implemented")

        #TODO: implement printing inside conditional
        def func(x):
            x1 = hcb.id_print(x, where="1")

            def body(x):
                x3 = hcb.id_print(x, where="w_1", output_stream=testing_stream)
                return hcb.id_print(x3 + 1,
                                    where="w_2",
                                    output_stream=testing_stream)

            x10 = lax.while_loop(
                lambda x: hcb.id_print(
                    x < 10, where="w_p", output_stream=testing_stream), body,
                x1)
            res = hcb.id_print(x10, where="10", output_stream=testing_stream)
            return res

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).as_hlo_text())

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(10, api.jit(func)(1))
        assertMultiLineStrippedEqual(self, """
""", testing_stream.output)
        testing_stream.reset()
示例#7
0
    def test_jit_nested(self):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)

            def func_nested(x):
                x2 = hcb.id_print(x + 1,
                                  where="nested",
                                  output_stream=testing_stream)
                return x2

            x3 = api.jit(func_nested)(x1)
            return hcb.id_print(x3 + 1,
                                where="3",
                                output_stream=testing_stream)

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).as_hlo_text())
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(3, api.jit(func)(1))
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: nested
2
where: 3
3""", testing_stream.output)
        testing_stream.reset()
示例#8
0
  def test_xla_computation_instantiate_constant_outputs(self):
    # f returns a constant. With instantiate_const_outputs=True the return
    # tuple is expected to hold exactly one (3, 4) result.
    def f():
      return np.zeros((3, 4))

    computation = api.xla_computation(f, instantiate_const_outputs=True)()
    [out_shape] = computation.GetReturnValueShape().tuple_shapes()
    self.assertEqual(out_shape.dimensions(), (3, 4))
示例#9
0
    def test_jit_nested_cond_no_print(self):
        """A nested conditional, without any prints"""
        raise SkipTest("skip this")

        # NOTE(review): unreachable while the SkipTest above is in place.
        @api.jit
        def cfun(x):
            return lax.cond(
                lax.lt(x, 2), lambda x: x,
                lambda x: lax.cond(x < 5, 3, lambda x: x, 4, lambda y: y), x)

        # Log the staged HLO the same way the sibling tests do, rather than
        # printing to stdout.
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(cfun)(1).as_hlo_text())
        cfun(1)
示例#10
0
    def test_while_cond(self, with_jit=False):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            def body(x):
                x3 = hcb.id_print(x,
                                  where="w_b_1",
                                  output_stream=testing_stream)
                x4 = lax.cond(
                    x % 2 == 0, x3 + 1, lambda x: hcb.id_print(
                        x, where="w_b_t", output_stream=testing_stream),
                    x3 + 1,
                    lambda x: hcb.id_print(-1,
                                           where="w_b_f",
                                           result=x,
                                           output_stream=testing_stream))
                return hcb.id_print(x4,
                                    where="w_b_2",
                                    output_stream=testing_stream)

            x10 = lax.while_loop(lambda x: x <= 3, body, x2)
            res = hcb.id_print(x10, where="end", output_stream=testing_stream)
            return res

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).as_hlo_text())
        transform = api.jit if with_jit else lambda f: f
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(4, transform(func)(1))
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: w_b_1
2
where: w_b_t
3
where: w_b_2
3
where: w_b_1
3
where: w_b_f
-1
where: w_b_2
4
where: end
4""", testing_stream.output)
        testing_stream.reset()
示例#11
0
  def testIssue810(self):
    """Regression test for issue 810: grad of a scan must not replicate A."""
    def loss(A):
      def step(x, i):
        return np.matmul(A, x), None
      init_x = np.zeros(A.shape[-1:])
      last_x, _ = lax.scan(step, init_x, np.arange(10))
      return np.sum(last_x)

    A = np.zeros((3, 3))
    # The second DUS was unnecessarily replicating A across time.
    # We check XLA because _scan_impl is "underneath" the jaxpr language.
    hlo_text = str(api.xla_computation(api.grad(loss))(A).GetHloText())
    # Use a TestCase assertion rather than a bare `assert`. The check then
    # still runs under `python -O` and reports the actual count on failure.
    self.assertLess(hlo_text.count("dynamic-update-slice("), 2)
示例#12
0
    def test_jit_simple(self):
        jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
            2. * x, what="here", output_stream=testing_stream))

        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(jit_fun1)(5.).as_hlo_text())
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            res = jit_fun1(5.)

        self.assertAllClose(6. * 5., res, check_dtypes=True)
        assertMultiLineStrippedEqual(self, """
what: here
10.00""", testing_stream.output)
        testing_stream.reset()
示例#13
0
  def test_jit_sequence1(self):
    def func(x):
      x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
      return hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

    logging.info("%s: %s", self._testMethodName,
          api.make_jaxpr(func)(1))
    logging.info("%s: %s", self._testMethodName,
          api.xla_computation(func)(1).as_hlo_text())

    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
      self.assertEqual(2, api.jit(func)(1))
    assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2""", testing_stream.output)
    testing_stream.reset()
示例#14
0
文件: api_test.py 项目: yyht/jax
    def test_staging_out_multi_replica(self):
        """Staging a pmap'd function out to XLA must not raise."""
        def f(x):
            return api.pmap(np.mean)(x)

        # Only the absence of an exception is checked; the text is discarded.
        api.xla_computation(f)(np.arange(8)).GetHloText()