コード例 #1
0
ファイル: api_test.py プロジェクト: yyht/jax
    def test_dot_mismatch(self):
        def f(x, y):
            return np.dot(x, y)

        jtu.check_raises_regexp(
            lambda: grad(f)(onp.zeros(3), onp.zeros(4)), TypeError,
            "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\).")
コード例 #2
0
 def test_casts(self):
     for castfun in [float, complex, hex, oct] + list(six.integer_types):
         f = lambda x: castfun(x)
         jtu.check_raises_regexp(
             lambda: jit(f)(0), TypeError,
             "('JaxprTracer' object cannot be interpreted as an integer"
             "|Abstract value passed to .*)")
コード例 #3
0
 def testMismatchedAxisSizes(self):
     n = xla_bridge.device_count()
     f = pmap(lambda x, y: x + y)
     jtu.check_raises_regexp(
         lambda: f(onp.random.randn(n), onp.random.randn(n - 1)),
         ValueError,
         "Axis size .* does not match leading dimension of shape .*")
コード例 #4
0
ファイル: core_test.py プロジェクト: orestmy/jax
    def test_check_jaxpr_eqn_mismatch(self):
        def f(x):
            return jnp.sin(x) + jnp.cos(x)

        def new_jaxpr():
            return make_jaxpr(f)(1.).jaxpr

        # jaxpr is:
        #
        # { lambda  ; a.
        #   let b = sin a
        #       c = cos a
        #       d = add b c
        #   in (d,) }
        #
        # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

        jaxpr = new_jaxpr()
        jaxpr.eqns[0].outvars[0].aval = make_shaped_array(2)  # int, not float!
        jtu.check_raises_regexp(
            lambda: core.check_jaxpr(jaxpr), TypeError,
            ("Jaxpr equation LHS .* is ShapedArray(.*), "
             "RHS is inferred as ShapedArray(.*), in '.* = sin .*'"))

        jaxpr = new_jaxpr()
        jaxpr.eqns[0].outvars[0].aval = make_shaped_array(np.ones((2, 3)))
        jtu.check_raises_regexp(
            lambda: core.check_jaxpr(jaxpr), TypeError,
            ("Jaxpr equation LHS .* is ShapedArray(.*), "
             "RHS is inferred as ShapedArray(.*), in '.* = sin .*'"))
コード例 #5
0
ファイル: core_test.py プロジェクト: yueyedeai/jax
    def test_check_jaxpr_eqn_mismatch(self):
        def f(x):
            return jnp.sin(x) + jnp.cos(x)

        def new_jaxpr():
            return make_jaxpr(f)(1.).jaxpr

        # jaxpr is:
        #
        # { lambda  ; a.
        #   let b = sin a
        #       c = cos a
        #       d = add b c
        #   in (d,) }
        #
        # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

        jaxpr = new_jaxpr()
        jaxpr.eqns[0].outvars[0].aval = make_shaped_array(2)  # int, not float!
        jtu.check_raises_regexp(
            lambda: core.check_jaxpr(jaxpr), TypeError,
            (r"Variable '.' inconsistently typed as ShapedArray(.*), "
             r"bound as ShapedArray(.*) in '. = sin .'"))

        jaxpr = new_jaxpr()
        jaxpr.eqns[0].outvars[0].aval = make_shaped_array(np.ones((2, 3)))
        jtu.check_raises_regexp(
            lambda: core.check_jaxpr(jaxpr), TypeError,
            (r"Variable '.' inconsistently typed as ShapedArray(.*), "
             r"bound as ShapedArray(.*) in '. = sin .'"))
コード例 #6
0
ファイル: api_test.py プロジェクト: chandrad143/jax
 def test_vmap_in_axes_tree_prefix_error(self):
     # https://github.com/google/jax/issues/795
     jtu.check_raises_regexp(
         lambda: api.vmap(lambda x: x, in_axes=(0, 0))(np.ones(3)),
         ValueError,
         "axes specification must be a tree prefix of the corresponding "
         r"value, got specification \(0, 0\) for value "
         r"PyTreeDef\(tuple, \[\*\]\).")
コード例 #7
0
ファイル: api_test.py プロジェクト: terry2012/jax
  def test_bad_input(self):
    def f(x):
      return x

    jtu.check_raises_regexp(lambda: grad(f)("foo"), TypeError,
                     "Argument 'foo' of type <.*'str'> is not a valid JAX type")

    jtu.check_raises_regexp(lambda: jit(f)("foo"), TypeError,
                     "Argument 'foo' of type <.*'str'> is not a valid JAX type")
コード例 #8
0
    def test_range_err(self):
        def f(x, n):
            for i in range(n):
                x = x + i
            return x

        assert jit(f, static_argnums=(1, ))(0, 5) == 10
        jtu.check_raises_regexp(
            lambda: jit(f)(0, 5), TypeError,
            "('JaxprTracer' object cannot be interpreted as an integer"
            "|Abstract value passed to .*)")
コード例 #9
0
ファイル: api_test.py プロジェクト: yyht/jax
 def test_jit_of_noncallable(self):
     jtu.check_raises_regexp(lambda: api.jit(3), TypeError,
                             "Expected a callable value.*")
コード例 #10
0
ファイル: api_test.py プロジェクト: yyht/jax
 def test_grad_of_int_errors(self):
     dfn = grad(lambda x: x**2)
     jtu.check_raises_regexp(
         lambda: dfn(3), TypeError,
         "Primal inputs to reverse-mode differentiation must be of float or "
         "complex type, got type int..")
コード例 #11
0
ファイル: api_test.py プロジェクト: yyht/jax
 def test_devicearray_delete(self):
     x = device_put(1.)
     x.delete()
     jtu.check_raises_regexp(lambda: repr(x), ValueError,
                             "DeviceValue has been deleted.")