예제 #1
0
파일: core_test.py 프로젝트: orestmy/jax
    def test_check_jaxpr_eqn_mismatch(self):
        """core.check_jaxpr must reject an equation whose output aval was tampered with.

        For each case a fresh jaxpr of ``sin(x) + cos(x)`` is traced, the aval
        of the first equation's output variable (the result of ``sin``) is
        overwritten with a ShapedArray built from a mismatching value, and
        ``core.check_jaxpr`` is expected to raise a TypeError whose message
        points at the ``sin`` equation.
        """
        def sin_plus_cos(x):
            return jnp.sin(x) + jnp.cos(x)

        def fresh_jaxpr():
            # Re-traced per case because each case mutates an aval in place.
            return make_jaxpr(sin_plus_cos)(1.).jaxpr

        # Traced jaxpr:
        #
        # { lambda  ; a.
        #   let b = sin a
        #       c = cos a
        #       d = add b c
        #   in (d,) }
        #
        # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'.
        expected_error = ("Jaxpr equation LHS .* is ShapedArray(.*), "
                          "RHS is inferred as ShapedArray(.*), in '.* = sin .*'")

        # First an int scalar (sin of a float input yields a float), then a
        # (2, 3) array (sin of a scalar input yields a scalar).
        for bogus_value in (2, np.ones((2, 3))):
            jaxpr = fresh_jaxpr()
            jaxpr.eqns[0].outvars[0].aval = make_shaped_array(bogus_value)
            # Default-arg binding pins this iteration's jaxpr in the closure.
            jtu.check_raises_regexp(
                lambda j=jaxpr: core.check_jaxpr(j), TypeError, expected_error)
예제 #2
0
    def test_check_jaxpr_eqn_mismatch(self):
        """core.check_jaxpr must flag a variable whose aval disagrees with its binding.

        For each case a fresh jaxpr of ``sin(x) + cos(x)`` is traced, the aval
        of the variable bound by the ``sin`` equation is overwritten with a
        ShapedArray built from a mismatching value, and ``core.check_jaxpr``
        is expected to raise ``core.JaxprTypeError`` naming that equation.
        """
        def sin_plus_cos(x):
            return jnp.sin(x) + jnp.cos(x)

        def fresh_jaxpr():
            # Re-traced per case because each case mutates an aval in place.
            return make_jaxpr(sin_plus_cos)(1.).jaxpr

        # Traced jaxpr:
        #
        # { lambda  ; a.
        #   let b = sin a
        #       c = cos a
        #       d = add b c
        #   in (d,) }
        #
        # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'.
        expected_error = (
            r"Variable '.' inconsistently typed as ShapedArray(.*), "
            r"bound as ShapedArray(.*)\n\nin equation:\n\n  . = sin .")

        # First an int scalar (sin of a float input yields a float), then a
        # (2, 3) array (sin of a scalar input yields a scalar).
        for bogus_value in (2, np.ones((2, 3))):
            jaxpr = fresh_jaxpr()
            jaxpr.eqns[0].outvars[0].aval = make_shaped_array(bogus_value)
            with self.assertRaisesRegex(core.JaxprTypeError, expected_error):
                core.check_jaxpr(jaxpr)
예제 #3
0
파일: core_test.py 프로젝트: yueyedeai/jax
    def test_check_jaxpr_eqn_mismatch(self):
        """core.check_jaxpr must flag a variable whose aval disagrees with its binding.

        Each case traces a fresh jaxpr of ``sin(x) + cos(x)``, overwrites the
        aval of the variable bound by the ``sin`` equation with a ShapedArray
        built from a mismatching value, and expects ``core.check_jaxpr`` to
        raise a TypeError whose message points at the ``sin`` equation.
        """
        def sin_plus_cos(x):
            return jnp.sin(x) + jnp.cos(x)

        def fresh_jaxpr():
            # Re-traced per case because each case mutates an aval in place.
            return make_jaxpr(sin_plus_cos)(1.).jaxpr

        # Traced jaxpr:
        #
        # { lambda  ; a.
        #   let b = sin a
        #       c = cos a
        #       d = add b c
        #   in (d,) }
        #
        # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'.
        expected_error = (r"Variable '.' inconsistently typed as ShapedArray(.*), "
                          r"bound as ShapedArray(.*) in '. = sin .'")

        def assert_rejected_with(bogus_value):
            # Corrupt the aval of 'b' and require check_jaxpr to reject it.
            jaxpr = fresh_jaxpr()
            jaxpr.eqns[0].outvars[0].aval = make_shaped_array(bogus_value)
            jtu.check_raises_regexp(
                lambda: core.check_jaxpr(jaxpr), TypeError, expected_error)

        assert_rejected_with(2)  # int, not float!
        assert_rejected_with(np.ones((2, 3)))  # (2, 3) array, not a scalar