Esempio n. 1
0
 def test_check_jaxpr_cond_invalid(self):
     """check_jaxpr rejects a cond whose branches take different input counts."""
     def two_branch_switch(x):
         return lax.switch(0, [jnp.sin, jnp.cos], x)

     jaxpr = make_jaxpr(two_branch_switch)(1.).jaxpr
     cond = next(e for e in jaxpr.eqns if e.primitive.name == 'cond')
     # Strip branch 0 of its inputs so it no longer agrees with branch 1.
     branch0 = cond.params['branches'][0]
     branch0.jaxpr.invars = ()
     expected_msg = 'cond branch 0 takes 0 inputs, branch 1 takes 1'
     with self.assertRaisesRegex(core.JaxprTypeError, expected_msg):
         core.check_jaxpr(jaxpr)
Esempio n. 2
0
    def test_jaxpr_dropvar_from_loop(self):
        """An unused while_loop output is bound to core.dropvar and type-checks."""
        def f(x):
            # Only the second loop-carried component is used by the caller.
            _, y = lax.while_loop(lambda s: s[0] < 0., lambda s:
                                  (jnp.sin(s[0]), jnp.cos(s[1])), (x, x))
            return y + 1.

        jaxpr = make_jaxpr(f)(1.).jaxpr
        # assertIs instead of a bare assert: `python -O` strips asserts, which
        # would silently turn this check into a no-op.
        self.assertIs(jaxpr.eqns[0].outvars[0], core.dropvar)
        core.check_jaxpr(jaxpr)
Esempio n. 3
0
 def test_jaxpr_undefined_eqn_invar(self):
     """check_jaxpr reports an equation input that nothing in the jaxpr binds."""
     jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
     cos_eqn = next(e for e in jaxpr.eqns if e.primitive.name == 'cos')
     # Replace the cos input with a fresh, never-bound variable of the same aval.
     fresh_var = core.gensym([jaxpr], suffix='_test')
     cos_eqn.invars[0] = fresh_var(cos_eqn.invars[0].aval)
     expected_msg = r"Variable '.+_test' not defined\n\nin equation:"
     with self.assertRaisesRegex(core.JaxprTypeError, expected_msg):
         core.check_jaxpr(jaxpr)
Esempio n. 4
0
    def test_jaxpr_dropvar_from_cond(self):
        """An unused cond output is bound to core.dropvar and type-checks."""
        def f(x):
            # Only the second output of each branch is used by the caller.
            _, y = lax.cond(x < 0., lambda x: (jnp.sin(x), x + 1.), lambda x:
                            (jnp.cos(x), x + 2.), x)
            return y

        jaxpr = make_jaxpr(f)(1.).jaxpr
        # assertIs instead of a bare assert: `python -O` strips asserts, which
        # would silently turn this check into a no-op.
        self.assertIs(jaxpr.eqns[-1].outvars[0], core.dropvar)
        core.check_jaxpr(jaxpr)
Esempio n. 5
0
 def test_check_jaxpr_scan_correct(self):
   """A well-formed scan jaxpr passes check_jaxpr without error."""
   def step(carry, x):
     b = jnp.cos(jnp.sum(jnp.sin(x)) + jnp.sum(jnp.cos(carry)))
     return jnp.sin(carry * b), b
   xs = jnp.ones((5, 3))
   init = jnp.ones(4)
   scan_fn = partial(lax.scan, step)
   core.check_jaxpr(make_jaxpr(scan_fn)(init, xs).jaxpr)
Esempio n. 6
0
  def test_const(self):
    """Returned constants appear as a jaxpr constvar and an inline literal."""
    def returns_consts(x):
      return (x, 1., np.zeros(1))

    expected = """
    { lambda b ;  ; a.
        let
        in [a, 1.0, b] }
    """
    traced = api.make_jaxpr(returns_consts)(0.)
    self.assertMultiLineStrippedEqual(str(traced), expected)
Esempio n. 7
0
    def testNormalize(self):
        """_parallelize(f) matches f numerically and its jaxpr contains a psum."""
        def normalize(v):
            return v / v.sum(0)

        x = onp.arange(4.)
        expected = normalize(x)
        self.assertAllClose(_parallelize(normalize)(x), expected,
                            check_dtypes=False)

        parallel_jaxpr = make_jaxpr(_parallelize(normalize))(x)
        self.assertIn('psum', repr(parallel_jaxpr))
Esempio n. 8
0
    def test_jaxpr_dropvar_from_jit_call(self):
        """An unused output of a jitted call is bound to core.dropvar."""
        def inner(x):
            return x + 1, x + 2

        def f(x):
            # Only the second output of the jitted call is used.
            _, y = jit(inner)(x)
            return y + 3

        jaxpr = make_jaxpr(f)(1).jaxpr
        # assertIs instead of a bare assert: `python -O` strips asserts, which
        # would silently turn this check into a no-op.
        self.assertIs(jaxpr.eqns[0].outvars[0], core.dropvar)
        core.check_jaxpr(jaxpr)
Esempio n. 9
0
    def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                       dimension_numbers, bdims):
        """dot_general batches correctly and its jaxpr has no transposes/broadcasts."""
        rng = jtu.rand_small(self.rng())
        dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
        self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                            (dtype, dtype), rng)

        # Checks that batching didn't introduce any transposes or broadcasts.
        # NOTE(review): the jaxpr traced here is of the unbatched `dot`, not of
        # a vmapped version; confirm whether a batched trace was intended.
        jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                    np.zeros(rhs_shape, dtype))
        for eqn in jtu.iter_eqns(jaxpr.jaxpr):
            # Compare primitive *names*, as other jaxpr tests do via
            # `eqn.primitive.name`. Testing the Primitive object itself for
            # membership in a list of strings never matches, so the original
            # check could not fail.
            self.assertNotIn(eqn.primitive.name, ("transpose", "broadcast"))
Esempio n. 10
0
  def testSelect(self):
    """papply of select matches a hand-written SPMD program and serial pmap."""
    pfun, axis_name = _papply(lax.select, 5,
                             in_axes=(None, 0, None))

    p = onp.arange(15).reshape((5, 3)) % 4 == 1
    t = onp.ones((5, 3))
    f = onp.zeros((5, 3))
    jaxpr = make_jaxpr(pfun)(p, t[0], f)

    def expected_spmd(p, t, f):
      return lax.select(
          lax_parallel.psplit_like(p, t, axis_name),
          t,
          lax_parallel.psplit_like(f, t, axis_name))

    expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
    # assertEqual rather than a bare assert: it is not stripped by
    # `python -O` and it shows both reprs on failure.
    self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

    ans = _serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
    expected = lax.select(p, t, f)
    self.assertAllClose(ans, expected, check_dtypes=True)
Esempio n. 11
0
    def test_grad_simple(self):
        def func(x):
            y = hcb.id_print(x * 2.,
                             what="x * 2",
                             output_stream=testing_stream)
            return x * hcb.id_print(
                y * 3., what="y * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul 1.00 a
      c d = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=(('jvp',), ('transpose',))
                    what=y * 3 ] b 0.00
      e = mul c 3.00
      f g = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=(('jvp',), ('transpose',))
                    what=x * 2 ] e 0.00
      h = mul f 2.00
      i = mul a 2.00
      j = id_tap[ arg_treedef=*
                  func=_print
                  nr_untapped=0
                  what=x * 2 ] i
      k = mul j 3.00
      l = id_tap[ arg_treedef=*
                  func=_print
                  nr_untapped=0
                  what=y * 3 ] k
      m = mul 1.00 l
      n = add_any h m
  in (n,) }""", str(api.make_jaxpr(grad_func)(5.)))

        with hcb.outfeed_receiver():
            res_grad = grad_func(jnp.float32(5.))
        self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: x * 2
10.00
what: y * 3
30.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3
5.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00""", testing_stream.output)
        testing_stream.reset()
Esempio n. 12
0
    def test_while_cond(self, with_jit=False):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            def body(x):
                x3 = hcb.id_print(x,
                                  where="w_b_1",
                                  output_stream=testing_stream)
                x4 = lax.cond(
                    x % 2 == 0, x3 + 1, lambda x: hcb.id_print(
                        x, where="w_b_t", output_stream=testing_stream),
                    x3 + 1,
                    lambda x: hcb.id_print(-1,
                                           where="w_b_f",
                                           result=x,
                                           output_stream=testing_stream))
                return hcb.id_print(x4,
                                    where="w_b_2",
                                    output_stream=testing_stream)

            x10 = lax.while_loop(lambda x: x <= 3, body, x2)
            res = hcb.id_print(x10, where="end", output_stream=testing_stream)
            return res

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).as_hlo_text())
        transform = api.jit if with_jit else lambda f: f
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(4, transform(func)(1))
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: w_b_1
2
where: w_b_t
3
where: w_b_2
3
where: w_b_1
3
where: w_b_f
-1
where: w_b_2
4
where: end
4""", testing_stream.output)
        testing_stream.reset()
Esempio n. 13
0
 def assertRewrite(self,
                   expected: str,
                   func: Callable,
                   args: Sequence,
                   has_input_token=True,
                   has_output_token=True):
     """Trace func(*args); the rewrite comparison is temporarily disabled.

     Only the tracing step runs, so a func that fails to trace still raises.
     `expected` and the token flags are accepted but currently unused.
     """
     traced = api.make_jaxpr(func)(*args)
     # TODO: re-enable when we change the host_callback rewriter
     #rewritten = hcb._rewrite_typed_jaxpr(traced,
     #                                     has_input_token, has_output_token)
     #assertMultiLineStrippedEqual(self, expected, str(rewritten))
     del traced
Esempio n. 14
0
 def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                         dimension_numbers, rng_factory):
   """Second-order fwd/rev grads of dot_general are correct and keep precision."""
   rng = rng_factory(self.rng())
   lhs = rng(lhs_shape, dtype)
   rhs = rng(rhs_shape, dtype)
   dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                         precision=lax.Precision.HIGHEST)
   check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
   # check that precision config is preserved in the VJP's jaxpr; assertIn
   # (unlike a bare assert) is not stripped by `python -O`.
   result, pullback = api.vjp(dot_general, lhs, rhs)
   gresult = lax.zeros_like_array(result)
   s = str(api.make_jaxpr(pullback)(gresult))
   self.assertIn("precision=HIGHEST", s)
Esempio n. 15
0
 def assertRewrite(self,
                   expected: str,
                   func: Callable,
                   args: Sequence,
                   has_input_token=True,
                   has_output_token=True):
     """Assert that host_callback's rewrite of func(*args) prints as `expected`.

     The rewrite's first returned element is the jaxpr that gets compared.
     """
     typed_jaxpr = api.make_jaxpr(func)(*args)
     rewrite_result = hcb._rewrite_typed_jaxpr(typed_jaxpr, has_input_token,
                                               has_output_token)
     rewritten = rewrite_result[0]
     assertMultiLineStrippedEqual(self, expected, str(rewritten))
Esempio n. 16
0
 def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
   """Second-order fwd/rev grads of lax.dot are correct and keep precision."""
   rng = rng_factory(self.rng())
   tol = {onp.float16: 1e-1, onp.float32: 1e-4}
   lhs = rng(lhs_shape, dtype)
   rhs = rng(rhs_shape, dtype)
   dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
   check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                        atol=tol, rtol=tol)
   # check that precision config is preserved in the VJP's jaxpr; assertIn
   # (unlike a bare assert) is not stripped by `python -O`.
   result, pullback = api.vjp(dot, lhs, rhs)
   gresult = lax.zeros_like_array(result)
   s = str(api.make_jaxpr(pullback)(gresult))
   self.assertIn("precision=HIGHEST", s)
Esempio n. 17
0
    def testNestedBatchingMatMat(self):
        """Two nested vmaps of vdot compute a matmul traced as one equation."""
        matvec = vmap(np.vdot, in_axes=(0, None))
        matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)

        randn = onp.random.RandomState(0).randn
        A = randn(4, 3)
        B = randn(3, 2)

        self.assertAllClose(matmat(A, B), onp.dot(A, B), check_dtypes=False)

        traced = make_jaxpr(matmat)(A, B)
        self.assertEqual(len(traced.eqns), 1)
Esempio n. 18
0
  def test_jarrett_jvps(self):
    """api.jarrett preserves values and VJPs; its VJP jaxpr has one constvar."""
    def f1(x):
      return np.sin(np.sin(np.sin(x)))
    f2 = api.jarrett(f1)

    for x in [3., onp.array([2., 3., 4.])]:
      self.assertAllClose(f1(x), f2(x), check_dtypes=True)

      _, f1_vjp = api.vjp(f1, x)
      _, f2_vjp = api.vjp(f2, x)
      self.assertAllClose(f1_vjp(x), f2_vjp(x), check_dtypes=True)

      jaxpr2 = api.make_jaxpr(f2_vjp)(x)
      # assertEqual instead of a bare assert: survives `python -O` and
      # reports the actual count on failure.
      self.assertEqual(len(jaxpr2.constvars), 1)
Esempio n. 19
0
    def test_jvp(self):
        jvp_fun1 = lambda x, xt: api.jvp(fun1, (x, ), (xt, ))
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a b.
  let c = mul a 2.00
      d = id_tap[ arg_treedef=*
                  func=_print
                  nr_untapped=0
                  what=a * 2 ] c
      e = mul d 3.00
      f g = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    what=y * 3 ] e d
      h = mul g g
      i = mul b 2.00
      j k = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=('jvp',)
                    what=a * 2 ] i d
      l = mul j 3.00
      m n o = id_tap[ arg_treedef=*
                      func=_print
                      nr_untapped=2
                      transforms=('jvp',)
                      what=y * 3 ] l j f
      p = mul n g
      q = mul g n
      r = add_any p q
  in (h, r) }""",
            str(api.make_jaxpr(jvp_fun1)(jnp.float32(5.), jnp.float32(0.1))))
        with hcb.outfeed_receiver():
            res_primals, res_tangents = jvp_fun1(jnp.float32(5.),
                                                 jnp.float32(0.1))
        self.assertAllClose(100., res_primals, check_dtypes=False)
        self.assertAllClose(4., res_tangents, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: a * 2
10.00
transforms: ('jvp',) what: a * 2
0.20
what: y * 3
30.00
transforms: ('jvp',) what: y * 3
0.60""", testing_stream.output)
        testing_stream.reset()
Esempio n. 20
0
  def test_jarrett_jvps2(self):
    """Two-argument api.jarrett matches f1; its VJP jaxpr has two constvars."""
    def f1(x, y):
      return np.sin(x) * np.cos(y) * np.sin(x) * np.cos(y)
    f2 = api.jarrett(f1)

    # TODO(mattjj): doesn't work for (3., onp.array([4., 5.]))
    for x, y in [(3., 4.), (onp.array([5., 6.]), onp.array([7., 8.]))]:
      self.assertAllClose(f1(x, y), f2(x, y), check_dtypes=True)

      _, f1_vjp = api.vjp(f1, x, y)
      _, f2_vjp = api.vjp(f2, x, y)
      self.assertAllClose(f1_vjp(y), f2_vjp(y), check_dtypes=True)

      jaxpr2 = api.make_jaxpr(f2_vjp)(y)
      # assertEqual instead of a bare assert: survives `python -O` and
      # reports the actual count on failure.
      self.assertEqual(len(jaxpr2.constvars), 2)
Esempio n. 21
0
  def testNestedBatchingMatMat(self):
    """Nested vmaps of jnp.vdot form a matmul traced as a single equation."""
    matvec = vmap(jnp.vdot, in_axes=(0, None))
    matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)

    randn = np.random.RandomState(0).randn
    A = randn(4, 3)
    B = randn(3, 2)

    computed = matmat(A, B)
    reference = np.dot(A, B)
    # Looser float32 tolerance on TPU (presumably due to its default matmul
    # precision -- TODO confirm).
    if jtu.device_under_test() == "tpu":
      rtol = {np.float32:1e-2}
    else:
      rtol = None
    self.assertAllClose(computed, reference, check_dtypes=False, rtol=rtol)

    closed = make_jaxpr(matmat)(A, B)
    self.assertEqual(len(closed.jaxpr.eqns), 1)
Esempio n. 22
0
File: api_test.py Progetto: yyht/jax
    def test_partial_eval_lower(self):
        """A jitted select over a closed-over bool mask lowers to a one-eqn subjaxpr.

        This is a simplified model of a bug that arose when @jit was first used
        inside a jvp rule; it lives in this file because it uses make_jaxpr.
        """
        @api.jit
        def f(a, b, c):
            wide_a = lax.broadcast(a, (2, ))
            return lax.select(wide_a, b, c)

        mask = onp.ones((3, 3), dtype=onp.bool_)
        b = onp.ones((2, 3, 3))
        c = onp.ones((2, 3, 3))

        def with_fixed_mask(b, c):
            return f(mask, b, c)

        jaxpr = api.make_jaxpr(with_fixed_mask)(b, c)
        subjaxpr = next(eqn.bound_subjaxprs[0][0]
                        for eqn in jaxpr.eqns if eqn.bound_subjaxprs)
        self.assertEqual(len(subjaxpr.eqns), 1)
Esempio n. 23
0
  def test_jit_sequence1(self):
    def func(x):
      x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
      return hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

    logging.info("%s: %s", self._testMethodName,
          api.make_jaxpr(func)(1))
    logging.info("%s: %s", self._testMethodName,
          api.xla_computation(func)(1).as_hlo_text())

    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
      self.assertEqual(2, api.jit(func)(1))
    assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2""", testing_stream.output)
    testing_stream.reset()
Esempio n. 24
0
  def test_with_tuple_results(self):
    def func2(x):
      x1, y1 = hcb.id_print((x * 2., x * 3.), output_stream=testing_stream)
      return x1 + y1

    assertMultiLineStrippedEqual(self, """
{ lambda  ; a.
  let b = mul a 2.00
      c = mul a 3.00
      d e = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
                    func=_print
                    ] b c
      f = add d e
  in (f,) }""", str(api.make_jaxpr(func2)(3.)))
    with hcb.outfeed_receiver():
      self.assertEqual(3. * (2. + 3.), func2(3.))
    assertMultiLineStrippedEqual(self, """
[ 6.00
  9.00 ]""", testing_stream.output)
    testing_stream.reset()
Esempio n. 25
0
  def test_jit_constant(self):
    def func(x):
      return hcb.id_print(42, result=x, output_stream=testing_stream)

    assertMultiLineStrippedEqual(self, """
{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b c = id_tap[ arg_treedef=*
                                                   func=_print
                                                   nr_untapped=1
                                                   ] 42 a
                                 in (c,) }
                    device=None
                    name=func ] a
  in (b,) }""", str(api.make_jaxpr(api.jit(func))(5)))
    self.assertEqual("", testing_stream.output)

    with hcb.outfeed_receiver():
      self.assertAllClose(5, api.jit(func)(5), check_dtypes=True)
    assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
    testing_stream.reset()
Esempio n. 26
0
  def test_vmap_not_batched(self):
    x = 3.
    def func(y):
      # x is not mapped, y is mapped
      _, y = hcb.id_print((x, y), output_stream=testing_stream)
      return x + y

    vmap_func = api.vmap(func)
    vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
    assertMultiLineStrippedEqual(self, """
{ lambda  ; a.
  let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
                    func=_print
                    transforms=(('batch', (None, 0)),) ] 3.00 a
      d = add c 3.00
  in (d,) }""", str(api.make_jaxpr(vmap_func)(vargs)))
    with hcb.outfeed_receiver():
      res_vmap = vmap_func(vargs)
    assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
[ 3.00
  [4.00 5.00] ]""", testing_stream.output)
    testing_stream.reset()
Esempio n. 27
0
 def test_cond(self):
   """Old-style cond with closures prints both branch jaxprs and their nconsts."""
   def f(x):
     def add_x(xt):
       return xt + x

     def sub_x(xf):
       return xf - x

     return lax.cond(x >= 0., x + 1., add_x, x + 2., sub_x)

   expected = """
   { lambda  ;  ; a.
     let b = ge a 0.0
         c = add a 1.0
         d = add a 2.0
         e = cond[ false_jaxpr={ lambda  ;  ; b a.
                                 let c = sub a b
                                 in [c] }
                   false_nconsts=1
                   true_jaxpr={ lambda  ;  ; b a.
                                let c = add a b
                                in [c] }
                   true_nconsts=1 ] b a c a d
     in [e] }
       """
   traced = api.make_jaxpr(f)(3.)
   self.assertMultiLineStrippedEqual(str(traced), expected)
Esempio n. 28
0
  def test_mask(self):
    """id_print under api.mask; currently skipped because masking has regressed.

    Everything after the `raise SkipTest` below is unreachable until the skip
    is removed; it is kept as the intended test body.
    """
    # TODO(necula)
    raise SkipTest("masking has regressed")
    @partial(api.mask, in_shapes=['n'], out_shape='')
    def padded_sum(x):
      return jnp.sum(hcb.id_print(x, what="x", output_stream=testing_stream))
    # A length-4 array whose logical size n is 2 (see `logical_shapes` below).
    args = [jnp.arange(4)], dict(n=np.int64(2))
    # The expected jaxpr embeds a Traced repr, so it is tied to the exact
    # tracer printing of the JAX version it was written against.
    assertMultiLineStrippedEqual(self, """
{ lambda c f ; a b.
  let d = lt c b
      e = id_tap[ func=_print
                  logical_shapes=[(Traced<ShapedArray(int32[]):JaxprTrace(level=0/0)>,)]
                  transforms=('mask',)
                  what=x ] a
      g = select d e f
      h = reduce_sum[ axes=(0,) ] g
  in (h,) }""", str(api.make_jaxpr(padded_sum)(*args)))

    # NOTE(review): unlike the other host_callback tests here, this call is
    # not wrapped in `hcb.outfeed_receiver()`; confirm before re-enabling.
    _ = padded_sum(*args)
    # NOTE(review): this uses self.assertMultiLineStrippedEqual(expected, actual)
    # rather than the module-level assertMultiLineStrippedEqual(self, ...)
    # used by sibling tests; confirm the method exists on this TestCase.
    self.assertMultiLineStrippedEqual("""
logical_shapes: [(2,)] transforms: ('mask',) what: x
[0 1 2 3]
   """, testing_stream.output)
    testing_stream.reset()
Esempio n. 29
0
    def test_grad_double(self):
        def func(x):
            y = hcb.id_print(x * 2.,
                             what="x * 2",
                             output_stream=testing_stream)
            return x * (y * 3.)

        grad_func = api.grad(api.grad(func))
        with hcb.outfeed_receiver():
            assertMultiLineStrippedEqual(
                self, """
{ lambda  ; a.
  let
  in (12.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
            # Just making the Jaxpr invokes the id_print twiceonce
            assertMultiLineStrippedEqual(
                self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00""", testing_stream.output)
            testing_stream.reset()
            res_grad = grad_func(jnp.float32(5.))

        self.assertAllClose(12., res_grad, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: x * 2
10.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00""", testing_stream.output)
        testing_stream.reset()
Esempio n. 30
0
  def test_eval(self):
    assertMultiLineStrippedEqual(self, """
{ lambda  ; a.
  let b = mul a 2.00
      c = id_tap[ arg_treedef=*
                  func=_print
                  what=a * 2 ] b
      d = mul c 3.00
      e f = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    what=y * 3 ] d c
      g = integer_pow[ y=2 ] f
  in (g,) }""", str(api.make_jaxpr(fun1)(5.)))
    self.assertEqual("", testing_stream.output)

    with hcb.outfeed_receiver():
      self.assertAllClose((5. * 2.) ** 2, fun1(5.), check_dtypes=True)
    assertMultiLineStrippedEqual(self, """
what: a * 2
10.00
what: y * 3
30.00""", testing_stream.output)
    testing_stream.reset()