Example #1 (score: 0)
    def test_binop_mismatch(self):
        """grad of elementwise add with mismatched shapes raises ValueError."""
        add = lambda a, b: a + b

        expected = "Incompatible shapes for broadcasting: ((3,), (4,))"
        jtu.check_raises(lambda: grad(add)(onp.zeros(3), onp.zeros(4)),
                         ValueError, expected)
Example #2 (score: 0)
  def test_dot_mismatch(self):
    """grad of np.dot on incompatible vector shapes raises TypeError."""
    dot = lambda a, b: np.dot(a, b)

    expected = "Incompatible shapes for dot: got (3,) and (4,)."
    jtu.check_raises(lambda: grad(dot)(onp.zeros(3), onp.zeros(4)),
                     TypeError, expected)
Example #3 (score: 0)
    def test_unimplemented_interpreter_rules(self):
        """Each missing interpreter rule yields its own NotImplementedError.

        Rules are added to the primitive one at a time, and after each
        addition the failure moves to the next missing rule — so statement
        order here is significant.
        """
        prim = Primitive('foo')

        def apply_prim(x):
            return prim.bind(x)

        # No rules defined yet: plain eval, jit, and grad all fail up front.
        jtu.check_raises(lambda: apply_prim(1.0), NotImplementedError,
                         "Evaluation rule for 'foo' not implemented")

        jtu.check_raises(lambda: jit(apply_prim)(1.0), NotImplementedError,
                         "Abstract evaluation for 'foo' not implemented")

        jtu.check_raises(
            lambda: grad(apply_prim)(1.0), NotImplementedError,
            "Forward-mode differentiation rule for 'foo' not implemented")

        # With an abstract-eval rule, jit gets further and fails at lowering.
        prim.def_abstract_eval(lambda x: x)

        jtu.check_raises(lambda: jit(apply_prim)(1.0), NotImplementedError,
                         "XLA translation rule for 'foo' not implemented")

        # With impl and JVP supplied, only the reverse-mode rule is missing.
        prim.def_impl(lambda x: x)
        defjvp(prim, lambda g, x: apply_prim(g))

        jtu.check_raises(
            lambda: grad(apply_prim)(1.0), NotImplementedError,
            "Reverse-mode differentiation rule for 'foo' not implemented")
Example #4 (score: 0)
File: api_test.py  Project: mitghi/jax
  def test_unwrapped_numpy(self):
    """Passing a tracer into raw (non-jax) numpy raises a helpful error."""
    raw_exp = lambda x: onp.exp(x)

    expected = ("Tracer can't be used with raw numpy functions. "
                "You might have\n  import numpy as np\ninstead of\n"
                "  import jax.numpy as np")
    jtu.check_raises(lambda: grad(raw_exp)(onp.zeros(3)), Exception, expected)
Example #5 (score: 0)
    def test_binop_mismatch(self):
        """Shape-mismatched add fails the same way with and without grad."""
        def f(x, y):
            return x + y

        # Both the plain call and the grad-transformed call hit the same
        # broadcasting check, so the expected message is shared.
        msg = "add got incompatible shapes for broadcasting: (3,), (4,)."

        jtu.check_raises(
            lambda: f(np.zeros(3), np.zeros(4)), TypeError, msg)

        jtu.check_raises(
            lambda: grad(f)(onp.zeros(3), onp.zeros(4)), TypeError, msg)
Example #6 (score: 0)
File: api_test.py  Project: mitghi/jax
  def test_switch_value_jit(self):
    """Python `if` on a traced comparison works under grad but not jit."""
    def f(x):
      if x > 0:
        return x
      return -x

    # grad traces with concrete values, so the branch can be resolved.
    assert grad(f)(1.0) == 1.0
    assert grad(f)(-1.0) == -1.0
    # jit traces abstractly: bool() on the abstract comparison must raise.
    jtu.check_raises(lambda: jit(f)(1), TypeError, concretization_err_msg(bool))
Example #7 (score: 0)
File: api_test.py  Project: yyht/jax
    def test_issue_871(self):
        """linearize's tangent function rejects wrongly-shaped tangents.

        Regression test: the check must fire for both the plain and the
        jit-wrapped function.
        """
        T = np.array([[1., 2.], [3., 4.], [5., 6.]])
        x = np.array([1, 2, 3])

        msg = ("linearized function called on tangent values "
               "inconsistent with the original primal values.")
        for fun in (np.sum, api.jit(np.sum)):
            y, f_jvp = api.linearize(fun, x)
            jtu.check_raises(lambda: f_jvp(T), ValueError, msg)
Example #8 (score: 0)
  def test_defvjp_closure_error(self):
    """Differentiating w.r.t. a closed-over variable of a defvjp'd function
    raises a ValueError rather than silently computing a wrong gradient."""
    def foo(x):
      @api.custom_transforms
      def bar(y):
        return x * y  # `x` is closed over, not a positional argument

      api.defvjp(bar, lambda g, ans, y: x * y)
      return bar(x)

    expected = (
        "Detected differentiation w.r.t. variables from outside "
        "the scope of <jax.custom_transforms function bar>, but defvjp and "
        "defvjp_all only support differentiation w.r.t. positional arguments.")
    jtu.check_raises(lambda: grad(foo)(1.,), ValueError, expected)
Example #9 (score: 0)
 def test_grad_nonscalar_output(self):
     """grad of an array-valued (non-scalar) function raises TypeError."""
     identity = lambda x: x
     jtu.check_raises(
         lambda: grad(identity)(onp.zeros(3)), TypeError,
         "Gradient only defined for scalar-output functions. ")
Example #10 (score: 0)
 def test_grad_tuple_output(self):
     """grad of a tuple-valued function raises TypeError."""
     pair = lambda x: (x, x)
     jtu.check_raises(
         lambda: grad(pair)(1.0), TypeError,
         "Gradient only defined for scalar-output functions. ")
Example #11 (score: 0)
 def test_grad_unit_output(self):
     """grad of a function returning the empty tuple raises TypeError."""
     unit = lambda x: ()
     jtu.check_raises(
         lambda: grad(unit)(onp.zeros(3)), TypeError,
         "Gradient only defined for scalar-output functions. "
         "Output was: ()")