Esempio n. 1
0
  def test_staging_nested_including_shape_arg(self):
    # Exercises the _get_tracers_only_in_shapes logic in partial_eval.py: the
    # axis-size value reaches the inner jitted function both explicitly (as
    # its first argument, x.shape[0]) and through the shapes of its arrays.
    size_aval = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
    x_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'),
                               weak_type=False)
    y_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'),
                               weak_type=False)

    @lu.wrap_init
    def f(x, y):
      @jax.jit
      def g(_, x, y, z, w):
        return (x, w)
      return g(x.shape[0], x, y, x, y)

    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        f, [size_aval, x_aval, y_aval], keep_inputs=[False, True, True])

    # The outer jaxpr should be a single call equation.
    self.assertLen(jaxpr.eqns, 1)
    (call_eqn,) = jaxpr.eqns
    self.assertIsInstance(call_eqn.primitive, core.CallPrimitive)
    inner = call_eqn.params['call_jaxpr']
    self.assertIsInstance(inner, core.Jaxpr)

    # One axis-size binder followed by four array binders, each shaped by it.
    self.assertLen(inner.invars, 1 + 4)
    size_var = inner.invars[0]
    for array_var in inner.invars[1:]:
      self.assertEqual((size_var,), array_var.aval.shape)
Esempio n. 2
0
    def test_staging_primitive_applications(self):
        # Stage out a short chain of primitives (mul -> sin -> reduce_sum) over
        # two arrays whose leading dimension is given by pe.DBIdx(0), which
        # presumably refers to the first (size) input.
        size_aval = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
        x_aval = core.DShapedArray((pe.DBIdx(0), ), jnp.dtype('float32'),
                                   weak_type=False)
        y_aval = core.DShapedArray((pe.DBIdx(0), ), jnp.dtype('float32'),
                                   weak_type=False)

        @lu.wrap_init
        def f(x, y):
            product = lax.mul(x, y)
            sine = lax.sin(product)
            total = lax_internal._reduce_sum(sine, [0])
            return (total, )

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [size_aval, x_aval, y_aval], keep_inputs=[False, True, True])

        # One axis-size binder plus the two array inputs.
        self.assertLen(jaxpr.invars, 1 + 2)
        # One equation for each primitive applied in f.
        self.assertLen(jaxpr.eqns, 3)
        mul_eqn = jaxpr.eqns[0]
        self.assertLen(mul_eqn.outvars, 1)
        # The elementwise product keeps the (dynamic) shape of its operand.
        self.assertEqual(mul_eqn.outvars[0].aval.shape,
                         jaxpr.invars[1].aval.shape)

        # Summing over axis 0 leaves a single scalar output.
        self.assertLen(jaxpr.outvars, 1)
        self.assertEqual(jaxpr.outvars[0].aval.shape, ())
Esempio n. 3
0
  def test_staging_nested(self):
    # Stage out a function that calls a jitted inner function on two arrays
    # sharing a dynamic leading dimension supplied as the first input.
    size_aval = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
    x_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'),
                               weak_type=False)
    y_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'),
                               weak_type=False)

    @lu.wrap_init
    def f(x, y):
      @jax.jit
      def g(x, y, z, w):
        return (x, w)
      return g(x, y, x, y)

    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        f, [size_aval, x_aval, y_aval], keep_inputs=[False, True, True])

    # Outer jaxpr: an axis-size binder followed by the two array inputs; all
    # array inputs and both outputs are shaped by that binder.
    self.assertLen(jaxpr.invars, 1 + 2)
    outer_size = jaxpr.invars[0]
    for var in jaxpr.invars[1:]:
      self.assertEqual((outer_size,), var.aval.shape)

    self.assertLen(jaxpr.outvars, 2)
    for var in jaxpr.outvars:
      self.assertEqual((outer_size,), var.aval.shape)

    # The body is a single call; the callee has its own axis-size binder that
    # shapes each of its four array binders.
    self.assertLen(jaxpr.eqns, 1)
    (call_eqn,) = jaxpr.eqns
    self.assertIsInstance(call_eqn.primitive, core.CallPrimitive)
    inner = call_eqn.params['call_jaxpr']
    self.assertIsInstance(inner, core.Jaxpr)

    self.assertLen(inner.invars, 1 + 4)
    inner_size = inner.invars[0]
    for var in inner.invars[1:]:
      self.assertEqual((inner_size,), var.aval.shape)
Esempio n. 4
0
    def test_typecheck_staging_nested(self):
        # core.check_jaxpr must accept the staged jaxpr of a nested jitted call,
        # and reject it once the call is given an operand, or a result binder,
        # whose type disagrees with the callee.
        size_a = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        size_b = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        # Per the staged jaxpr below, these become c:f32[a] and d:f32[b], i.e.
        # arrays with independent dynamic lengths.
        arr_a = core.DShapedArray((DBIdx(0), ), jnp.dtype('float32'),
                                  weak_type=False)
        arr_b = core.DShapedArray((DBIdx(1), ), jnp.dtype('float32'),
                                  weak_type=False)

        @lu.wrap_init
        def f(a, b):
            @jax.jit
            def g(x):
                return x

            return g(a),

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [size_a, size_b, arr_a, arr_b],
            keep_inputs=[False, False, True, True])
        # Staged jaxpr:
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a c
        #   in (e,) }
        core.check_jaxpr(jaxpr)  # the unmodified jaxpr type-checks

        call_eqn = jaxpr.eqns[0]
        _, _, c_var, d_var = jaxpr.invars

        # Break it: pass d:f32[b] where the callee expects an array whose
        # length is the call's first operand (a).
        call_eqn.invars[1] = d_var
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a d   !!! type error here !!!
        #   in (e,) }
        with self.assertRaisesRegex(TypeError, "passes operand"):
            core.check_jaxpr(jaxpr)

        # Undo that change; the jaxpr is well-typed again.
        call_eqn.invars[1] = c_var
        core.check_jaxpr(jaxpr)

        # Break it differently: rebind the call result to a fresh variable of
        # type f32[b], which disagrees with what the callee produces.
        call_eqn.outvars[0] = core.Var(0, '', d_var.aval)
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[b] = xla_call[   !!! type error here !!!
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a c
        #   in (h,) }
        with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
            core.check_jaxpr(jaxpr)
Esempio n. 5
0
 def make_aval(arg, spec):
   """Build the abstract value used to trace ``arg``.

   Args:
     arg: an array-like value exposing ``shape`` and ``dtype``.
     spec: mapping from an axis index of ``arg`` to an axis name; a falsy
       (e.g. empty) ``spec`` means no axis of ``arg`` is abstracted.

   Returns:
     ``shaped_abstractify(arg)`` when ``spec`` is falsy. Otherwise a
     ``core.DShapedArray`` in which every named axis is replaced by the entry
     of ``env`` for its name; unnamed axes keep their concrete size.

   Side effects:
     Records the concrete size of every named axis in ``sizes`` (the first
     occurrence of a name wins), then asserts that every occurrence agrees.
     ``sizes`` and ``env`` are presumably dicts from the enclosing scope;
     TODO confirm against the caller.
   """
   if not spec:
     return shaped_abstractify(arg)
   # Populate ``sizes`` outside the ``assert``: assert statements are removed
   # entirely under ``python -O``, which would silently skip this bookkeeping.
   for i, name in spec.items():
     sizes.setdefault(name, arg.shape[i])
   assert all(arg.shape[i] == sizes[name] for i, name in spec.items())
   shape = [env[spec[i]] if i in spec else d for i, d in enumerate(arg.shape)]
   return core.DShapedArray(tuple(shape), arg.dtype, False)
Esempio n. 6
0
    def test_staging_nested_including_shape_arg(self):
        # Stage out a call to a jitted function that receives the dynamic axis
        # size both explicitly (its first argument, x.shape[0]) and implicitly
        # through the shapes of its four array arguments.
        n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
        a = core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)
        b = core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)

        @lu.wrap_init
        def f(x, y):
            @jax.jit
            def g(_, x, y, z, w):
                return (x, w)

            return g(x.shape[0], x, y, x, y)

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [n, a, b], keep_inputs=[False, True, True])

        # Reference rendering of the staged jaxpr. It is kept here for readers;
        # the test no longer prints it, which keeps test output clean.
        # { lambda ; a:i32[] b:f32[a] c:f32[a]. let
        #     d:f32[a] e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let
        #
        #         in (h, k) }
        #       name=g
        #     ] a a b c b c
        #   in (d, e) }

        self.assertLen(jaxpr.eqns, 1)
        eqn = jaxpr.eqns[0]
        self.assertIsInstance(eqn.primitive, core.CallPrimitive)
        inner_jaxpr = eqn.params['call_jaxpr']
        self.assertIsInstance(inner_jaxpr, core.Jaxpr)

        # NOTE(review): the reference jaxpr above shows six inner binders
        # (f, g:i32[] then four f32[f] arrays), whereas the assertions below
        # expect five binders with invars[1:5] shaped by invars[0]. Confirm
        # which one reflects current partial_eval behavior.
        self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[1].aval.shape)
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[2].aval.shape)
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[3].aval.shape)
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[4].aval.shape)
Esempio n. 7
0
  def test_staging_basic(self):
    # An identity function over two arrays that share a dynamic leading
    # dimension should stage to a jaxpr whose array inputs and outputs are
    # all shaped by the leading size binder.
    size_aval = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
    x_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'),
                               weak_type=False)
    y_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'),
                               weak_type=False)

    @lu.wrap_init
    def f(x, y):
      return x, y

    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        f, [size_aval, x_aval, y_aval], keep_inputs=[False, True, True])

    self.assertLen(jaxpr.invars, 3)
    size_var, *array_vars = jaxpr.invars
    for var in array_vars:
      self.assertEqual((size_var,), var.aval.shape)

    self.assertLen(jaxpr.outvars, 2)
    for var in jaxpr.outvars:
      self.assertEqual((size_var,), var.aval.shape)
Esempio n. 8
0
File: utils.py Progetto: wayfeng/jax
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
                           named_shape_rule, *avals, **kwargs):
    """Compute the output abstract value of a single-result primitive.

    ``weak_type_rule`` is always evaluated first. The aval class among
    ``avals`` picked by ``_max`` keyed on ``array_abstraction_level`` then
    decides the form of the result:

    * ``core.ConcreteArray``: ``prim.impl`` runs on the inputs' ``.val``
      values, and the result is wrapped in a ``ConcreteArray``.
    * ``core.ShapedArray``: shape, dtype and named shape come from
      ``shape_rule``, ``dtype_rule`` and ``named_shape_rule`` (in that order).
    * ``core.DShapedArray``: shape and dtype come from their rules.
    * ``core.UnshapedArray``: only ``dtype_rule`` is consulted.

    Every rule is called as ``rule(*avals, **kwargs)``.

    Raises:
      TypeError: if the selected aval class is none of the above.
    """
    assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
    assert not prim.multiple_results
    weak_type = weak_type_rule(*avals, **kwargs)
    by_level = operator.attrgetter('array_abstraction_level')
    least_specialized = _max([type(aval) for aval in avals], key=by_level)

    if least_specialized is core.ConcreteArray:
        concrete_inputs = [aval.val for aval in avals]
        out = prim.impl(*concrete_inputs, **kwargs)
        return core.ConcreteArray(out.dtype, out, weak_type=weak_type)

    if least_specialized is core.ShapedArray:
        shape = shape_rule(*avals, **kwargs)
        dtype = dtype_rule(*avals, **kwargs)
        named_shape = named_shape_rule(*avals, **kwargs)
        return core.ShapedArray(shape, dtype, weak_type=weak_type,
                                named_shape=named_shape)

    if least_specialized is core.DShapedArray:
        shape = shape_rule(*avals, **kwargs)
        dtype = dtype_rule(*avals, **kwargs)
        return core.DShapedArray(shape, dtype, weak_type)

    if least_specialized is core.UnshapedArray:
        dtype = dtype_rule(*avals, **kwargs)
        return core.UnshapedArray(dtype, weak_type=weak_type)

    raise TypeError(avals, least_specialized)