def test_dot_mismatch(self):
    """Differentiating a dot product of length-3 and length-4 vectors.

    Uses jtu.check_raises_regexp to expect a TypeError whose message
    reports the two incompatible shapes.
    """
    def dot_product(lhs, rhs):
        return np.dot(lhs, rhs)

    def differentiate_mismatched_operands():
        return grad(dot_product)(onp.zeros(3), onp.zeros(4))

    # `L?` tolerates the long-integer suffix that may appear in shape reprs.
    expected_message = (
        "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\).")
    jtu.check_raises_regexp(
        differentiate_mismatched_operands, TypeError, expected_message)
def test_casts(self):
    """Python cast builtins applied to a jit-traced argument.

    For each of float, complex, hex, oct and every type in
    six.integer_types, expects (via jtu.check_raises_regexp) a TypeError
    matching one of the two accepted messages.
    """
    expected_message = (
        "('JaxprTracer' object cannot be interpreted as an integer"
        "|Abstract value passed to .*)")

    cast_functions = [float, complex, hex, oct]
    cast_functions.extend(six.integer_types)

    for cast in cast_functions:
        def apply_cast(value):
            return cast(value)

        jtu.check_raises_regexp(
            lambda: jit(apply_cast)(0), TypeError, expected_message)
def testMismatchedAxisSizes(self):
    """pmap over two arguments whose leading dimensions differ by one.

    Expects (via jtu.check_raises_regexp) a ValueError about the axis
    size not matching the leading dimension.
    """
    device_total = xla_bridge.device_count()
    mapped_add = pmap(lambda x, y: x + y)

    def call_with_unequal_lengths():
        # First operand has one entry per device, the second has one fewer.
        return mapped_add(onp.random.randn(device_total),
                          onp.random.randn(device_total - 1))

    jtu.check_raises_regexp(
        call_with_unequal_lengths,
        ValueError,
        "Axis size .* does not match leading dimension of shape .*")
def test_check_jaxpr_eqn_mismatch(self):
    """core.check_jaxpr on a jaxpr whose first equation output aval is altered.

    The aval of eqns[0].outvars[0] is overwritten twice on freshly traced
    jaxprs (once from the int 2, once from a (2, 3) array of ones); each
    time a TypeError matching the LHS/RHS message is expected.
    """
    def f(x):
        return jnp.sin(x) + jnp.cos(x)

    # jaxpr is:
    #
    # { lambda  ; a.
    #   let b = sin a
    #       c = cos a
    #       d = add b c
    #   in (d,) }
    #
    # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

    expected_message = ("Jaxpr equation LHS .* is ShapedArray(.*), "
                        "RHS is inferred as ShapedArray(.*), in '.* = sin .*'")

    def expect_rejection(replacement_value):
        # Trace anew on every call so the mutation below starts from a
        # fresh jaxpr object.
        traced = make_jaxpr(f)(1.).jaxpr
        traced.eqns[0].outvars[0].aval = make_shaped_array(replacement_value)
        jtu.check_raises_regexp(
            lambda: core.check_jaxpr(traced), TypeError, expected_message)

    expect_rejection(2)  # int, not float!
    expect_rejection(np.ones((2, 3)))
def test_check_jaxpr_eqn_mismatch(self):
    """core.check_jaxpr on a jaxpr whose first equation output aval is altered.

    The aval of eqns[0].outvars[0] is overwritten twice on freshly traced
    jaxprs (once from the int 2, once from a (2, 3) array of ones); each
    time a TypeError matching the "inconsistently typed" message is
    expected.
    """
    def f(x):
        return jnp.sin(x) + jnp.cos(x)

    # jaxpr is:
    #
    # { lambda  ; a.
    #   let b = sin a
    #       c = cos a
    #       d = add b c
    #   in (d,) }
    #
    # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

    expected_message = (r"Variable '.' inconsistently typed as ShapedArray(.*), "
                        r"bound as ShapedArray(.*) in '. = sin .'")

    def expect_rejection(replacement_value):
        # Trace anew on every call so the mutation below starts from a
        # fresh jaxpr object.
        traced = make_jaxpr(f)(1.).jaxpr
        traced.eqns[0].outvars[0].aval = make_shaped_array(replacement_value)
        jtu.check_raises_regexp(
            lambda: core.check_jaxpr(traced), TypeError, expected_message)

    expect_rejection(2)  # int, not float!
    expect_rejection(np.ones((2, 3)))
def test_vmap_in_axes_tree_prefix_error(self):
    """vmap with a two-entry in_axes tuple applied to a single argument.

    Expects (via jtu.check_raises_regexp) a ValueError stating that the
    axes specification must be a tree prefix of the value.
    """
    # https://github.com/google/jax/issues/795
    def vmap_with_oversized_in_axes():
        return api.vmap(lambda x: x, in_axes=(0, 0))(np.ones(3))

    expected_message = (
        "axes specification must be a tree prefix of the corresponding "
        r"value, got specification \(0, 0\) for value "
        r"PyTreeDef\(tuple, \[\*\]\).")
    jtu.check_raises_regexp(
        vmap_with_oversized_in_axes, ValueError, expected_message)
def test_bad_input(self):
    """Passing the string "foo" through grad and then through jit.

    Both calls are expected (via jtu.check_raises_regexp) to raise a
    TypeError naming the argument as not a valid JAX type.
    """
    def f(x):
        return x

    expected_message = (
        "Argument 'foo' of type <.*'str'> is not a valid JAX type")

    def expect_rejection(transform):
        jtu.check_raises_regexp(
            lambda: transform(f)("foo"), TypeError, expected_message)

    expect_rejection(grad)
    expect_rejection(jit)
def test_range_err(self):
    """range() over a jit argument, with and without static_argnums.

    With argument 1 marked static the jitted call must equal 10; without
    it, a TypeError matching one of the two accepted messages is expected.
    """
    def accumulate(total, count):
        for step in range(count):
            total = total + step
        return total

    static_result = jit(accumulate, static_argnums=(1, ))(0, 5)
    assert static_result == 10

    expected_message = (
        "('JaxprTracer' object cannot be interpreted as an integer"
        "|Abstract value passed to .*)")
    jtu.check_raises_regexp(
        lambda: jit(accumulate)(0, 5), TypeError, expected_message)
def test_jit_of_noncallable(self):
    """api.jit applied to the integer 3 is expected to raise TypeError."""
    def jit_an_integer():
        return api.jit(3)

    jtu.check_raises_regexp(
        jit_an_integer, TypeError, "Expected a callable value.*")
def test_grad_of_int_errors(self):
    """Calling the gradient of x**2 with the Python int 3.

    Expects (via jtu.check_raises_regexp) a TypeError saying primal
    inputs must be of float or complex type.
    """
    square_gradient = grad(lambda x: x**2)

    def differentiate_at_int():
        return square_gradient(3)

    expected_message = (
        "Primal inputs to reverse-mode differentiation must be of float or "
        "complex type, got type int..")
    jtu.check_raises_regexp(differentiate_at_int, TypeError, expected_message)
def test_devicearray_delete(self):
    """repr() of a device_put value after delete() is expected to raise ValueError."""
    deleted_value = device_put(1.)
    deleted_value.delete()

    def take_repr():
        return repr(deleted_value)

    jtu.check_raises_regexp(
        take_repr, ValueError, "DeviceValue has been deleted.")