コード例 #1
0
ファイル: core_test.py プロジェクト: xueeinstein/jax
    def test_typecheck_staging_nested(self):
        """core.check_jaxpr must catch type errors on a nested call equation.

        Traces ``f`` (which applies a jitted identity ``g`` to its first array
        argument) with dynamic-shape avals. It checks that the resulting jaxpr
        type-checks, then corrupts the first equation in two ways. Each
        corruption must make ``core.check_jaxpr`` raise ``TypeError``:

        1. passing the ``f32[b]`` input where the callee expects ``f32[a]``
           (operand mismatch);
        2. replacing the equation's output binder with one typed ``f32[b]``
           (result-binder mismatch).
        """
        # Two int32 scalar avals; per the expected printout below they become
        # the sizes ``a`` and ``b`` of the two array inputs.
        n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        # 1-D float32 avals whose dimension is a DBIdx; presumably DBIdx(i)
        # refers to the i-th input aval in the list passed to the tracer
        # (n or m) -- consistent with the "c:f32[a] d:f32[b]" printout below.
        a = core.DShapedArray((DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)
        b = core.DShapedArray((DBIdx(1), ),
                              jnp.dtype('float32'),
                              weak_type=False)

        @lu.wrap_init
        def f(a, b):
            # A jitted identity, so that tracing f stages out a nested call
            # equation (shown as xla_call in the printout below).
            @jax.jit
            def g(x):
                return x

            # Trailing comma: f returns a 1-tuple.
            return g(a),

        # keep_inputs=False for n and m: f only takes (a, b), so the two size
        # scalars are not passed to f itself (still listed as jaxpr invars).
        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [n, m, a, b], keep_inputs=[False, False, True, True])
        # Expected jaxpr:
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a c
        #   in (e,) }
        core.check_jaxpr(jaxpr)  # no problems here...

        # Let's introduce a type error by applying the called jaxpr to arguments
        # with types which aren't consistent with its input binders:
        # c and d are the f32[a] and f32[b] input binders respectively.
        _, _, c, d = jaxpr.invars
        # In-place mutation of the first (only) equation; undone further below.
        jaxpr.eqns[0].invars[1] = d
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a d   !!! type error here !!!
        #   in (e,) }
        with self.assertRaisesRegex(TypeError, "passes operand"):
            core.check_jaxpr(jaxpr)

        # Restore the original jaxpr:
        jaxpr.eqns[0].invars[1] = c
        core.check_jaxpr(jaxpr)  # no problems here...

        # Let's introduce another type error by setting the call result let binders
        # to have the wrong type:
        # The new binder carries d's aval (f32[b]) instead of f32[a]. This
        # mutation is not undone; jaxpr is left corrupted when the test ends.
        jaxpr.eqns[0].outvars[0] = core.Var(0, '', d.aval)
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[b] = xla_call[   !!! type error here !!!
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a c
        #   in (h,) }
        with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
            core.check_jaxpr(jaxpr)
コード例 #2
0
ファイル: conditionals.py プロジェクト: xueeinstein/jax
def cond_bind(*args, branches, linear):
    """Bind ``cond_p`` to ``args`` via ``core.AxisPrimitive.bind``.

    If ``config.jax_enable_checks`` is set, the call is validated first:
    ``_cond_typecheck`` runs on placeholder Vars that carry each operand's
    aval, and ``core.check_jaxpr`` runs on the inner ``.jaxpr`` of every entry
    in ``branches`` (presumably ClosedJaxprs -- confirm at call sites).

    Returns whatever ``core.AxisPrimitive.bind`` returns.
    """
    if config.jax_enable_checks:
        # Placeholder Vars (all with count 0 and empty suffix) standing in for
        # the operands, each typed with that operand's aval.
        placeholders = [core.Var(0, '', aval)
                        for aval in _map(core.get_aval, args)]
        _cond_typecheck(*placeholders, branches=branches, linear=linear)
        for branch in branches:
            core.check_jaxpr(branch.jaxpr)
    bind_params = dict(branches=branches, linear=linear)
    return core.AxisPrimitive.bind(cond_p, *args, **bind_params)
コード例 #3
0
 def mk_new_var(aval: core.AbstractValue) -> core.Var:
     """Build a ``core.Var`` of type ``aval`` with an empty suffix.

     The Var's first argument is the next value drawn from the module-level
     iterator ``mk_new_id``. NOTE(review): presumably ``mk_new_id`` is an
     ``itertools.count()``, which would give every returned Var a distinct
     id -- confirm where it is defined.
     """
     next_id = next(mk_new_id)
     return core.Var(next_id, '', aval)