Ejemplo n.º 1
0
Archivo: api_test.py Proyecto: yyht/jax
    def test_dot_mismatch(self):
        """grad of np.dot on length-3 and length-4 vectors raises TypeError."""
        def dot_product(lhs, rhs):
            return np.dot(lhs, rhs)

        # The optional "L" presumably tolerates Python 2 long-integer shape
        # reprs such as (3L,).
        expected = "Incompatible shapes for dot: got \\(3L?,\\) and \\(4L?,\\)."
        jtu.check_raises_regexp(
            lambda: grad(dot_product)(onp.zeros(3), onp.zeros(4)),
            TypeError,
            expected,
        )
Ejemplo n.º 2
0
 def test_casts(self):
     """Python casts applied to a jit-traced value raise TypeError."""
     # The group accepts either of two alternative error messages.
     expected = ("('JaxprTracer' object cannot be interpreted as an integer"
                 "|Abstract value passed to .*)")
     casts = [float, complex, hex, oct]
     casts.extend(six.integer_types)
     for cast in casts:
         def apply_cast(x):
             return cast(x)

         # The thunk runs synchronously inside this iteration, so `cast`
         # is still bound to the current loop value when it executes.
         jtu.check_raises_regexp(
             lambda: jit(apply_cast)(0), TypeError, expected)
Ejemplo n.º 3
0
 def testMismatchedAxisSizes(self):
     """pmap rejects mapped arguments whose leading axis sizes differ."""
     num_devices = xla_bridge.device_count()
     add = pmap(lambda x, y: x + y)

     def call_with_mismatched_sizes():
         # The second argument is one element shorter than the first.
         return add(onp.random.randn(num_devices),
                    onp.random.randn(num_devices - 1))

     jtu.check_raises_regexp(
         call_with_mismatched_sizes,
         ValueError,
         "Axis size .* does not match leading dimension of shape .*")
Ejemplo n.º 4
0
    def test_check_jaxpr_eqn_mismatch(self):
        """check_jaxpr rejects an equation whose output var has a bad aval.

        The aval of the first equation's output variable is overwritten,
        first with an int aval and then with a (2, 3) array aval. Each
        corruption must make ``core.check_jaxpr`` raise TypeError.
        """
        def f(x):
            return jnp.sin(x) + jnp.cos(x)

        def new_jaxpr():
            return make_jaxpr(f)(1.).jaxpr

        # jaxpr is:
        #
        # { lambda  ; a.
        #   let b = sin a
        #       c = cos a
        #       d = add b c
        #   in (d,) }
        #
        # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

        # The parentheses are escaped so they match literally. Unescaped,
        # "ShapedArray(.*)" is a regex group that accepts "ShapedArray"
        # followed by anything.
        expected = (r"Jaxpr equation LHS .* is ShapedArray\(.*\), "
                    r"RHS is inferred as ShapedArray\(.*\), in '.* = sin .*'")

        def check_rejects(bad_value):
            # Start from a freshly traced jaxpr for each case.
            jaxpr = new_jaxpr()
            jaxpr.eqns[0].outvars[0].aval = make_shaped_array(bad_value)
            jtu.check_raises_regexp(
                lambda: core.check_jaxpr(jaxpr), TypeError, expected)

        check_rejects(2)  # int, not float!
        check_rejects(np.ones((2, 3)))  # rank-2 array, but f was traced on a scalar
Ejemplo n.º 5
0
    def test_check_jaxpr_eqn_mismatch(self):
        """check_jaxpr rejects a variable whose aval disagrees with its binding.

        The aval of the first equation's output variable is overwritten,
        first with an int aval and then with a (2, 3) array aval. Each
        corruption must make ``core.check_jaxpr`` raise TypeError.
        """
        def f(x):
            return jnp.sin(x) + jnp.cos(x)

        def new_jaxpr():
            return make_jaxpr(f)(1.).jaxpr

        # jaxpr is:
        #
        # { lambda  ; a.
        #   let b = sin a
        #       c = cos a
        #       d = add b c
        #   in (d,) }
        #
        # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

        # '.' matches a one-character variable name such as 'b'. The
        # parentheses are escaped so "ShapedArray(...)" matches literally
        # instead of forming a regex group.
        expected = (r"Variable '.' inconsistently typed as ShapedArray\(.*\), "
                    r"bound as ShapedArray\(.*\) in '. = sin .'")

        def check_rejects(bad_value):
            # Start from a freshly traced jaxpr for each case.
            jaxpr = new_jaxpr()
            jaxpr.eqns[0].outvars[0].aval = make_shaped_array(bad_value)
            jtu.check_raises_regexp(
                lambda: core.check_jaxpr(jaxpr), TypeError, expected)

        check_rejects(2)  # int, not float!
        check_rejects(np.ones((2, 3)))  # rank-2 array, but f was traced on a scalar
Ejemplo n.º 6
0
 def test_vmap_in_axes_tree_prefix_error(self):
     """vmap rejects an in_axes spec that is not a prefix of the args tree."""
     # https://github.com/google/jax/issues/795
     def call_with_bad_axes():
         # Two axis entries are given for a single positional argument.
         return api.vmap(lambda x: x, in_axes=(0, 0))(np.ones(3))

     pattern = ("axes specification must be a tree prefix of the corresponding "
                r"value, got specification \(0, 0\) for value "
                r"PyTreeDef\(tuple, \[\*\]\).")
     jtu.check_raises_regexp(call_with_bad_axes, ValueError, pattern)
Ejemplo n.º 7
0
  def test_bad_input(self):
    """grad and jit both reject a str argument with TypeError."""
    def identity(x):
      return x

    # "<.*'str'>" accepts both "<type 'str'>" and "<class 'str'>" reprs.
    message = "Argument 'foo' of type <.*'str'> is not a valid JAX type"
    for transform in (grad, jit):
      # The thunk is invoked before the loop advances, so `transform` is
      # the current one when it runs.
      jtu.check_raises_regexp(lambda: transform(identity)("foo"), TypeError,
                              message)
Ejemplo n.º 8
0
    def test_range_err(self):
        """range() over a non-static jit argument raises TypeError."""
        def accumulate(x, n):
            total = x
            for step in range(n):
                total = total + step
            return total

        # With n marked static, range(n) sees a plain int: 0+1+2+3+4 == 10.
        jitted_static_n = jit(accumulate, static_argnums=(1, ))
        assert jitted_static_n(0, 5) == 10

        jtu.check_raises_regexp(
            lambda: jit(accumulate)(0, 5), TypeError,
            "('JaxprTracer' object cannot be interpreted as an integer"
            "|Abstract value passed to .*)")
Ejemplo n.º 9
0
Archivo: api_test.py Proyecto: yyht/jax
 def test_jit_of_noncallable(self):
     """api.jit raises TypeError when given a non-callable (an int)."""
     def jit_an_int():
         return api.jit(3)

     jtu.check_raises_regexp(jit_an_int, TypeError,
                             "Expected a callable value.*")
Ejemplo n.º 10
0
Archivo: api_test.py Proyecto: yyht/jax
 def test_grad_of_int_errors(self):
     """Reverse-mode grad raises TypeError for an integer primal input."""
     square_grad = grad(lambda x: x**2)

     def differentiate_int():
         return square_grad(3)

     # NOTE(review): the trailing ".." are regex wildcards, presumably
     # absorbing the dtype's bit width (e.g. "int32") — confirm.
     jtu.check_raises_regexp(
         differentiate_int,
         TypeError,
         "Primal inputs to reverse-mode differentiation must be of float or "
         "complex type, got type int..")
Ejemplo n.º 11
0
Archivo: api_test.py Proyecto: yyht/jax
 def test_devicearray_delete(self):
     """repr() of a device array after delete() raises ValueError."""
     value = device_put(1.)
     value.delete()

     def show_deleted():
         return repr(value)

     jtu.check_raises_regexp(show_deleted, ValueError,
                             "DeviceValue has been deleted.")