def test_binop_mismatch(self):
    # NOTE(review): a second method named `test_binop_mismatch` appears later in
    # this chunk (the one expecting TypeError / "add got incompatible shapes
    # for broadcasting"). Assuming both are methods of the same test class (the
    # class header is not visible here), the later `def` rebinds the name in the
    # class body, so this definition is never collected or run. Its expected
    # exception type and message also disagree with that later version, so it
    # looks stale -- consider deleting this copy rather than renaming it.
    def f(x, y):
        return x + y

    # Expects grad over mismatched (3,) + (4,) shapes to raise ValueError.
    jtu.check_raises(lambda: grad(f)(onp.zeros(3), onp.zeros(4)),
                     ValueError,
                     "Incompatible shapes for broadcasting: ((3,), (4,))")
def test_dot_mismatch(self):
    """grad through np.dot of length-3 and length-4 vectors raises TypeError."""
    def vec_dot(x, y):
        return np.dot(x, y)

    expected = "Incompatible shapes for dot: got (3,) and (4,)."
    jtu.check_raises(
        lambda: grad(vec_dot)(onp.zeros(3), onp.zeros(4)),
        TypeError,
        expected,
    )
def test_unimplemented_interpreter_rules(self):
    # Checks that each interpreter rule missing from a freshly created
    # primitive surfaces as NotImplementedError with a rule-specific message.
    # Rules are registered one at a time below, so the ORDER of statements
    # matters: each registration is expected to let the next check get further
    # before it reaches the next missing rule.
    foo_p = Primitive('foo')

    def foo(x):
        return foo_p.bind(x)

    # No rules registered yet: eager evaluation, jit and grad are each expected
    # to fail with the message for the rule they need first.
    jtu.check_raises(lambda: foo(1.0), NotImplementedError,
                     "Evaluation rule for 'foo' not implemented")
    jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                     "Abstract evaluation for 'foo' not implemented")
    jtu.check_raises(
        lambda: grad(foo)(1.0), NotImplementedError,
        "Forward-mode differentiation rule for 'foo' not implemented")

    # After registering an identity abstract-eval rule, jit is expected to
    # fail later, at XLA translation.
    foo_p.def_abstract_eval(lambda x: x)
    jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
                     "XLA translation rule for 'foo' not implemented")

    # Register an identity impl and a JVP rule that applies foo to the tangent.
    # This test registers no transpose rule, so grad is expected to fail at
    # the reverse-mode step.
    foo_p.def_impl(lambda x: x)
    defjvp(foo_p, lambda g, x: foo(g))
    jtu.check_raises(
        lambda: grad(foo)(1.0), NotImplementedError,
        "Reverse-mode differentiation rule for 'foo' not implemented")
def test_unwrapped_numpy(self):
    """Feeding a tracer to a raw NumPy function raises a hint-bearing error."""
    expected = ("Tracer can't be used with raw numpy functions. "
                "You might have\n import numpy as np\ninstead of\n"
                " import jax.numpy as np")

    def exp_via_numpy(x):
        return onp.exp(x)

    jtu.check_raises(lambda: grad(exp_via_numpy)(onp.zeros(3)),
                     Exception, expected)
def test_binop_mismatch(self):
    """Adding (3,) and (4,) arrays fails the same way eagerly and under grad."""
    def plus(x, y):
        return x + y

    msg = "add got incompatible shapes for broadcasting: (3,), (4,)."
    thunks = (
        lambda: plus(np.zeros(3), np.zeros(4)),
        lambda: grad(plus)(onp.zeros(3), onp.zeros(4)),
    )
    for thunk in thunks:
        jtu.check_raises(thunk, TypeError, msg)
def test_switch_value_jit(self):
    """Python branching on a traced value works under grad but not under jit."""
    def piecewise_abs(x):
        return x if x > 0 else -x

    d_abs = grad(piecewise_abs)
    assert d_abs(1.0) == 1.0
    assert d_abs(-1.0) == -1.0

    jtu.check_raises(lambda: jit(piecewise_abs)(1), TypeError,
                     concretization_err_msg(bool))
def test_issue_871(self):
    """Regression for issue 871: tangents whose shape mismatches the primal."""
    tangents = np.array([[1., 2.], [3., 4.], [5., 6.]])
    primal = np.array([1, 2, 3])
    expected = ("linearized function called on tangent values "
                "inconsistent with the original primal values.")

    # Check both the plain and the jit-wrapped function.
    for fun in (np.sum, api.jit(np.sum)):
        _, f_jvp = api.linearize(fun, primal)
        jtu.check_raises(lambda: f_jvp(tangents), ValueError, expected)
def test_defvjp_closure_error(self):
    """defvjp must reject differentiation w.r.t. a value closed over by bar."""
    def foo(captured):
        # `bar` must keep this exact name: it appears in the expected message.
        @api.custom_transforms
        def bar(y):
            return captured * y

        api.defvjp(bar, lambda g, ans, y: captured * y)
        return bar(captured)

    expected = (
        "Detected differentiation w.r.t. variables from outside "
        "the scope of <jax.custom_transforms function bar>, but defvjp and "
        "defvjp_all only support differentiation w.r.t. positional arguments."
    )
    jtu.check_raises(lambda: grad(foo)(1.0), ValueError, expected)
def test_grad_nonscalar_output(self):
    """grad of a function returning a length-3 array raises TypeError."""
    def identity(x):
        return x

    jtu.check_raises(lambda: grad(identity)(onp.zeros(3)), TypeError,
                     "Gradient only defined for scalar-output functions. ")
def test_grad_tuple_output(self):
    """grad of a function returning a tuple raises TypeError."""
    def duplicate(x):
        return (x, x)

    jtu.check_raises(lambda: grad(duplicate)(1.0), TypeError,
                     "Gradient only defined for scalar-output functions. ")
def test_grad_unit_output(self):
    """grad of a function returning () raises TypeError naming the output."""
    def nothing(x):
        return ()

    expected = ("Gradient only defined for scalar-output functions. "
                "Output was: ()")
    jtu.check_raises(lambda: grad(nothing)(onp.zeros(3)), TypeError, expected)