Example #1
 def testTranspose(self, shape, dtype, perm, bdims):
     rng = jtu.rand_default(self.rng())
     op = lambda x: lax.transpose(x, perm)
     self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)
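
A minimal, self-contained sketch of what the batching check verifies: vmapping lax.transpose over a new leading axis should agree with transposing each slice in a Python loop. The permutation and batch values below are arbitrary illustrations, not taken from the test suite.

import numpy as np
import jax.numpy as jnp
from jax import lax, vmap

# Batch of five 3x4 operands; transpose each one with the same permutation.
perm = (1, 0)
batch = np.arange(5 * 3 * 4, dtype=np.float32).reshape(5, 3, 4)

batched = vmap(lambda x: lax.transpose(x, perm))(batch)      # one vmapped call
looped = jnp.stack([lax.transpose(x, perm) for x in batch])  # reference loop

assert batched.shape == (5, 4, 3)
np.testing.assert_allclose(batched, looped)
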
Example #2
class LaxBackedNumpyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Numpy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                    dtypes),
       "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
      for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                                 JAX_COMPOUND_OP_RECORDS)
      for shapes in CombosWithReplacement(all_shapes, rec.nargs)
      for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
  def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
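
The harness helpers _CheckAgainstNumpy and _CompileAndCheck boil down to "compare the jax.numpy result against NumPy, then compare the jitted result against the eager one". A minimal sketch of that pattern using only public APIs; the op (tanh) and the input values are arbitrary choices for illustration.

import numpy as np
import jax.numpy as jnp
from jax import jit

x = np.linspace(-1.0, 1.0, 12, dtype=np.float32).reshape(3, 4)

expected = np.tanh(x)        # NumPy reference
actual = jnp.tanh(x)         # jax.numpy implementation
compiled = jit(jnp.tanh)(x)  # same op under jit

np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-7)
np.testing.assert_allclose(compiled, actual, rtol=1e-6, atol=1e-7)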

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_RECORDS
      for shape in all_shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True])
  def testReducer(self, onp_op, lnp_op, rng, shape, dtype, axis, keepdims):
    onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
    lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_axis={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis}
      for rec in JAX_ARGMINMAX_RECORDS
      for shape in all_shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape)))
  def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis):

    def onp_fun(array_to_reduce):
      return onp_op(array_to_reduce, axis)

    def lnp_fun(array_to_reduce):
      return lnp_op(array_to_reduce, axis)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("matrix-scalar", (3, 3), ()),
          ("scalar-matrix", (), (3, 3)),
          ("matrix-vector", (4, 5), (5,)),
          ("vector-matrix", (6,), (6, 4)),
          ("matrix-matrix", (3, 4), (4, 5)),
          ("tensor-vector", (4, 3, 2), (2,)),
          ("vector-tensor", (2,), (3, 2, 4)),
          ("tensor-matrix", (4, 3, 2), (2, 5)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2))
  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("vector-vector", (3,), (3,)),
          ("matrix-vector", (3, 3), (3,)),
          ("vector-matrix", (3,), (3, 3)),
          ("matrix-matrix", (3, 3), (3, 3)),
          ("vector-tensor", (3,), (5, 3, 2)),
          ("tensor-vector", (5, 3, 2), (2,)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-matrix", (5, 2, 3), (3, 2)),
          ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
          ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2))
  def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True)
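
The "tensor-tensor-broadcast" case above relies on matmul treating all leading dimensions as batch dimensions and broadcasting them, which jnp.matmul implements the same way as NumPy. A small sketch of that behaviour with the shapes taken from that case (the random values are arbitrary):

import numpy as np
import jax.numpy as jnp

rng = np.random.RandomState(0)
lhs = rng.randn(3, 1, 3, 4).astype(np.float32)  # batch dims (3, 1)
rhs = rng.randn(5, 4, 1).astype(np.float32)     # batch dim  (5,)

out = jnp.matmul(lhs, rhs)                      # batch dims broadcast to (3, 5)
assert out.shape == (3, 5, 3, 1)
np.testing.assert_allclose(out, np.matmul(lhs, rhs), rtol=1e-5, atol=1e-5)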

  @parameterized.named_parameters(
      {"testcase_name": "_{}_amin={}_amax={}".format(
          jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
       "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for a_min, a_max in [(-1, None), (None, 1), (-1, 1)])
  def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng):
    onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
    lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_decimals={}".format(
          jtu.format_shape_dtype_string(shape, dtype), decimals),
       "shape": shape, "dtype": dtype, "decimals": decimals,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for decimals in [0, 1, -2])
  def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
    onp_fun = lambda x: onp.round(x, decimals=decimals)
    lnp_fun = lambda x: lnp.round(x, decimals=decimals)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
          axis, ",".join(str(d) for d in base_shape),
          ",".join(onp.dtype(dtype).name for dtype in dtypes)),
       "axis": axis, "base_shape": base_shape, "dtypes": dtypes,
       "rng": jtu.rand_default()}
      for num_arrs in [3]
      for dtypes in CombosWithReplacement(default_dtypes, num_arrs)
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for axis in range(-len(base_shape)+1, len(base_shape)))
  def testConcatenate(self, axis, base_shape, dtypes, rng):
    wrapped_axis = axis % len(base_shape)
    shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)]
    onp_fun = lambda *args: onp.concatenate(args, axis=axis)
    lnp_fun = lambda *args: lnp.concatenate(args, axis=axis)

    def args_maker():
      return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(
          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
       "shape": shape, "dtypes": dtypes, "rng": rng}
      for dtypes in [
        [onp.float32],
        [onp.float32, onp.float32],
        [onp.float32, onp.int32, onp.float32],
        [onp.float32, onp.int64, onp.float32],
        [onp.float32, onp.int32, onp.float64],
      ]
      for shape in [(), (2,), (3, 4), (1, 100)]
      for rng in [jtu.rand_default()])
  def testStack(self, shape, dtypes, rng):
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
    self._CheckAgainstNumpy(onp.stack, lnp.stack, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape=[{}]_indtype={}_outdtype={}".format(
          "_".join(str(d) for d in shape),
          onp.dtype(fill_value_dtype).name, onp.dtype(out_dtype).name),
       "shape": shape, "fill_value_dtype": fill_value_dtype,
       "out_dtype": out_dtype, "rng": jtu.rand_default()}
      for shape in all_shapes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes)
  def testFull(self, shape, fill_value_dtype, out_dtype, rng):
    onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
    lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)
    args_maker = lambda: [rng((), fill_value_dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_axis={}_{}sections".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
       "shape": shape, "num_sections": num_sections, "axis": axis,
       "dtype": dtype, "rng": jtu.rand_default()}
      for shape, axis, num_sections in [
          ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
          ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
      for dtype in default_dtypes)
  def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng):
    onp_fun = lambda x: onp.split(x, num_sections, axis=axis)
    lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_outshape={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype)),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for arg_shape, out_shape in [
          ((3, 4), 12),
          ((3, 4), (12,)),
          ((3, 4), -1),
          ((2, 1, 4), (-1,)),
          ((2, 2, 4), (2, 8))
      ])
  def testReshape(self, arg_shape, out_shape, dtype, rng):
    onp_fun = lambda x: onp.reshape(x, out_shape)
    lnp_fun = lambda x: lnp.reshape(x, out_shape)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_expanddim={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), dim),
       "arg_shape": arg_shape, "dtype": dtype, "dim": dim,
       "rng": jtu.rand_default()}
      for arg_shape in [(), (3,), (3, 4)]
      for dtype in default_dtypes
      for dim in range(-len(arg_shape)+1, len(arg_shape)))
  def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng):
    onp_fun = lambda x: onp.expand_dims(x, dim)
    lnp_fun = lambda x: lnp.expand_dims(x, dim)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_axes=({},{})".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
       "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2,
       "rng": jtu.rand_default()}
      for arg_shape, ax1, ax2 in [
          ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
          ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
      for dtype in default_dtypes)
  def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng):
    onp_fun = lambda x: onp.swapaxes(x, ax1, ax2)
    lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_axis={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax),
       "arg_shape": arg_shape, "dtype": dtype, "ax": ax,
       "rng": jtu.rand_default()}
      for arg_shape, ax in [
          ((3, 1), None),
          ((3, 1), 1),
          ((1, 3, 1), (0, 2)),
          ((1, 4, 1), (0,))]
      for dtype in default_dtypes)
  def testSqueeze(self, arg_shape, dtype, ax, rng):
    onp_fun = lambda x: onp.squeeze(x, ax)
    lnp_fun = lambda x: lnp.squeeze(x, ax)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_arg{}".format(i), "arg": arg}
      for i, arg in enumerate([
          [1, 2, 3], [1., 2., 3.],
          [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]],
          [[3, onp.array(2), 1], onp.arange(3.)],
      ]))
  def testArray(self, arg):
    args_maker = lambda: [arg]
    self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True)

  def testAllClose(self):
    rng = onp.random.RandomState(0)
    x = rng.randn(2, 2)
    y = rng.randn(2)

    def same(list1, list2):
      allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
      elements_close = list(map(allclose, list1, list2))
      return lnp.all(lnp.array(elements_close))

    csame = api.jit(same)

    a1 = same((x, y), (x, y))
    a2 = csame((x, y), (x, y))
    a3 = csame((x, y), (x, 2 * y))

    self.assertTrue(a1)
    self.assertTrue(a2)
    self.assertFalse(a3)

  @jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
  def DISABLED_testOnesBroadcastingConstantHandler(self):
    # TODO(mattjj): update this test for jax3

    def fun(x):
      ones = lnp.ones((3, 4))
      assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

      # To check that the constant handler generates a Broadcast for stride-zero
      # arrays, we monkey-patch the client instance.
      # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
      # we can check the HLO more directly.
      c = x._node.c
      Broadcast = c.Broadcast  # pylint: disable=invalid-name
      was_called = []
      c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
      out = x + ones  # the ndarray constant handler should call Broadcast here
      assert was_called, "Broadcast was not called."

      return out

    fun = api.jit(fun)
    out_val = fun(lnp.ones(4))
    self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)

  def testZeroStridesConstantHandler(self):
    raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
    const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))

    def fun(x):
      return x * const

    fun = api.jit(fun)
    out_val = fun(3.)
    self.assertAllClose(out_val, 3. * const, check_dtypes=False)
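
What makes this constant interesting is that onp.broadcast_to returns a view whose broadcast dimensions have stride 0, which is exactly the case the constant handler has to recognise. A NumPy-only illustration (shapes copied from the test; the printed strides assume a C-contiguous float64 base array):

import numpy as np

raw = np.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
const = np.broadcast_to(raw, (3, 2, 3, 4, 5, 6))

print(const.shape)    # (3, 2, 3, 4, 5, 6)
print(const.strides)  # stride 0 in every broadcast dimension: (0, 40, 0, 0, 8, 0)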

  def testIsInstanceNdarrayDuringTracing(self):
    arr = onp.ones(3)

    @api.jit
    def f(x):
      self.assertIsInstance(x, lnp.ndarray)
      return lnp.sum(x)

    f(arr)


  def testNonArrayErrorMessage(self):
    x = [1., 2.]
    y = onp.array([3., 4.])

    def g(x, y):
      return lnp.add(x, y)

    def f(x, y):
      return lnp.dot(x, y)

    self.assertRaises(TypeError, lambda: g(x, y))
    self.assertRaises(TypeError, lambda: f(x, y))
    self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
    self.assertRaises(TypeError, lambda: api.jit(f)(x, y))

  def testAbstractionErrorMessage(self):

    @api.jit
    def f(x, n):
      for _ in range(n):
        x = x * x
      return x

    self.assertRaises(TypeError, lambda: f(3., 3))

    @api.jit
    def g(x):
      if x > 0.:
        return x * 2
      else:
        return x + 2

    self.assertRaises(TypeError, lambda: g(3.))
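
Both failures come from tracing with abstract values: n is abstract, so range(n) fails, and x > 0. produces an abstract boolean that cannot drive Python control flow. For the first case, current JAX offers jit's static_argnums parameter, which re-traces for each concrete value of n; a minimal sketch (not part of this test, which deliberately checks the error):

import jax
from functools import partial

@partial(jax.jit, static_argnums=1)  # treat n as a compile-time constant
def f(x, n):
  for _ in range(n):                 # Python loop unrolls during tracing
    x = x * x
  return x

print(f(3.0, 3))  # 6561.0, i.e. 3 ** 8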

  def DISABLED_testTracingPrimitiveWithNoTranslationErrorMessage(self):
    # TODO(mattjj): update this for jax3
    foo = lnp._not_implemented(lambda x: x)

    # No error if there's no tracing.
    foo(onp.arange(3))

    cfoo = api.jit(foo)
    self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))

  # TODO(mattjj): test infix operator overrides

  def DISABLED_testRavel(self):
    # TODO(mattjj): support this method-based syntax?
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)
Example #3
class NumpyLinalgTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testCholesky(self, shape, dtype, rng):
    def args_maker():
      factor_shape = shape[:-1] + (2 * shape[-1],)
      a = rng(factor_shape, dtype)
      return [onp.matmul(a, np.conj(T(a)))]

    self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.cholesky, args_maker, check_dtypes=True)

    if onp.finfo(dtype).bits == 64:
      jtu.check_grads(np.linalg.cholesky, args_maker(), order=2)
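
The args_maker above builds a symmetric positive-definite input by multiplying a random rectangular factor with its (conjugate) transpose, since Cholesky is only defined for such matrices. The same construction and property check in plain NumPy, with arbitrary sizes:

import numpy as np

rng = np.random.RandomState(0)
factor = rng.randn(4, 8)
a = factor @ factor.T              # symmetric positive definite (almost surely)

l = np.linalg.cholesky(a)          # lower-triangular Cholesky factor
np.testing.assert_allclose(l @ l.T, a, rtol=1e-10, atol=1e-12)
assert np.allclose(l, np.tril(l))  # factor is lower triangular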

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng": rng}
      for n in [0, 4, 5, 25]  # TODO(mattjj): complex64 unstable on large sizes?
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testDet(self, n, dtype, rng):
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.det, np.linalg.det, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng": rng}
      for n in [0, 4, 10, 200]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testSlogdet(self, n, dtype, rng):
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
           jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEig(self, shape, dtype, rng):
    self.skipTest("Test disabled until Jaxlib 0.1.15 is released") # TODO(phawkins)
    n = shape[-1]
    args_maker = lambda: [rng(shape, dtype)]

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / ((n + 1) * onp.finfo(dtype).eps)

    a, = args_maker()
    w, v = np.linalg.eig(a)
    self.assertTrue(onp.all(norm(onp.matmul(a, v) - w[..., None, :] * v) < 100))

    self._CompileAndCheck(partial(np.linalg.eig), args_maker,
                          check_dtypes=True, rtol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (5, 5)]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testEigBatching(self, shape, dtype, rng):
    self.skipTest("Test disabled until Jaxlib 0.1.15 is released") # TODO(phawkins)
    shape = (10,) + shape
    args = rng(shape, dtype)
    ws, vs = vmap(np.linalg.eig)(args)
    self.assertTrue(onp.all(onp.linalg.norm(
        onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}_lower={}".format(
           jtu.format_shape_dtype_string((n,n), dtype), lower),
       "n": n, "dtype": dtype, "lower": lower, "rng": rng}
      for n in [0, 4, 5, 50]
      for dtype in float_types() + complex_types()
      for lower in [False, True]
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEigh(self, n, dtype, lower, rng):
    args_maker = lambda: [rng((n, n), dtype)]

    uplo = "L" if lower else "U"

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / ((n + 1) * onp.finfo(dtype).eps)

    a, = args_maker()
    a = (a + onp.conj(a.T)) / 2
    w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a),
                          UPLO=uplo, symmetrize_input=False)
    self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
    self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30)

    self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker,
                          check_dtypes=True, rtol=1e-3)
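
The two assertions above check the defining properties of a Hermitian eigendecomposition: the eigenvectors are orthonormal (conj(V)^T V ≈ I) and A V ≈ V diag(w), with residuals normalised by dimension and machine epsilon. The same properties checked in plain NumPy for a small real symmetric matrix (sizes arbitrary):

import numpy as np

rng = np.random.RandomState(0)
n = 5
a = rng.randn(n, n)
a = (a + a.T) / 2         # symmetrize

w, v = np.linalg.eigh(a)  # ascending eigenvalues; columns of v are eigenvectors
np.testing.assert_allclose(v.T @ v, np.eye(n), atol=1e-12)  # orthonormal
np.testing.assert_allclose(a @ v, v * w, atol=1e-12)        # A v_i = w_i v_i, column-wise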

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                   lower),
       "shape": shape, "dtype": dtype, "rng": rng, "lower":lower}
      for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]
      for lower in [True, False]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEighGrad(self, shape, dtype, rng, lower):
    self.skipTest("Test fails with numeric errors.")
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + onp.conj(a.T)) / 2
    a = onp.tril(a) if lower else onp.triu(a)
    # Gradient checks will fail without symmetrization as the eigh jvp rule
    # is only correct for tangents in the symmetric subspace, whereas the
    # checker checks against unconstrained (co)tangents.
    if dtype not in complex_types():
      f = partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)
    else:  # only check eigenvalue grads for complex matrices
      f = lambda a: partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
    jtu.check_grads(f, (a,), 2, rtol=1e-1)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                   lower),
       "shape": shape, "dtype": dtype, "rng": rng, "lower":lower, "eps":eps}
      for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
      for dtype in complex_types()
      for rng in [jtu.rand_default()]
      for lower in [True, False]
      for eps in [1e-4]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
    # Special case to test for complex eigenvector grad correctness.
    # Exact eigenvector coordinate gradients are hard to test numerically for complex
    # eigensystem solvers given the extra degrees of per-eigenvector phase freedom.
    # Instead, we numerically verify the eigensystem properties on the perturbed
    # eigenvectors.  You only ever want to optimize eigenvector directions, not coordinates!
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + onp.conj(a.T)) / 2
    a = onp.tril(a) if lower else onp.triu(a)
    a_dot = eps * rng(shape, dtype)
    a_dot = (a_dot + onp.conj(a_dot.T)) / 2
    a_dot = onp.tril(a_dot) if lower else onp.triu(a_dot)
    # evaluate eigenvector gradient and groundtruth eigensystem for perturbed input matrix
    f = partial(np.linalg.eigh, UPLO=uplo)
    (w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,))
    new_a = a + a_dot
    new_w, new_v = f(new_a)
    new_a = (new_a + onp.conj(new_a.T)) / 2
    # Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues.
    RTOL=1e-2
    assert onp.max(
      onp.abs((onp.diag(onp.dot(onp.conj((v+dv).T), onp.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
    # Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues.
    assert onp.max(
      onp.linalg.norm(onp.abs(new_w*(v+dv) - onp.dot(new_a, (v+dv))), axis=0) /
      onp.linalg.norm(onp.abs(new_w*(v+dv)), axis=0)
    ) < RTOL

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (5, 5)]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testEighBatching(self, shape, dtype, rng):
    self.skipTest("Test disabled until Jaxlib 0.1.15 is released") # TODO(phawkins)
    shape = (10,) + shape
    args = rng(shape, dtype)
    args = (args + onp.conj(T(args))) / 2
    ws, vs = vmap(jsp.linalg.eigh)(args)
    self.assertTrue(onp.all(onp.linalg.norm(
        onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_ord={}_axis={}_keepdims={}".format(
         jtu.format_shape_dtype_string(shape, dtype), ord, axis, keepdims),
       "shape": shape, "dtype": dtype, "axis": axis, "keepdims": keepdims,
       "ord": ord, "rng": rng}
      for axis, shape in [
        (None, (1,)), (None, (7,)), (None, (5, 8)),
        (0, (9,)), (0, (4, 5)), ((1,), (10, 7, 3)), ((-2,), (4, 8)),
        (-1, (6, 3)), ((0, 2), (3, 4, 5)), ((2, 0), (7, 8, 9))]
      for keepdims in [False, True]
      for ord in (
          [None, 0, 1, 2, 3, -1, -2, -3, np.inf, -np.inf]
          if (axis is None and len(shape) == 1) or
             isinstance(axis, int) or
             (isinstance(axis, tuple) and len(axis) == 1)
          else [None, 'fro', 1, 2, -1, -2, np.inf, -np.inf, 'nuc'])
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  def testNorm(self, shape, dtype, ord, axis, keepdims, rng):
    # TODO(mattjj,phawkins): re-enable after checking internal tests
    self.skipTest("internal test failures")

    if (ord in ('nuc', 2, -2) and isinstance(axis, tuple) and len(axis) == 2 and
        (not FLAGS.jax_test_dut or not FLAGS.jax_test_dut.startswith("cpu") or
         len(shape) != 2)):
      raise SkipTest("No adequate SVD implementation available")

    args_maker = lambda: [rng(shape, dtype)]
    onp_fn = partial(onp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format(
          jtu.format_shape_dtype_string((m, n), dtype), full_matrices, compute_uv),
       "m": m, "n": n, "dtype": dtype, "full_matrices": full_matrices,
       "compute_uv": compute_uv, "rng": rng}
      for m in [2, 7, 29, 53]
      for n in [2, 7, 29, 53]
      for dtype in float_types() + complex_types()
      for full_matrices in [False, True]
      for compute_uv in [False, True]
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
    args_maker = lambda: [rng((m, n), dtype)]

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / (max(m, n) * onp.finfo(dtype).eps)

    a, = args_maker()
    out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

    if compute_uv:
      # Check the reconstructed matrices
      if full_matrices:
        k = min(m, n)
        if m < n:
          self.assertTrue(onp.all(norm(a - onp.matmul(out[1] * out[0], out[2][:k, :])) < 50))
        else:
          self.assertTrue(onp.all(norm(a - onp.matmul(out[1] * out[0][:, :k], out[2])) < 50))
      else:
          self.assertTrue(onp.all(norm(a - onp.matmul(out[1] * out[0], out[2])) < 50))

      # Check the unitary properties of the singular vector matrices.
      self.assertTrue(onp.all(norm(onp.eye(out[0].shape[1]) - onp.matmul(onp.conj(T(out[0])), out[0])) < 10))
      if m >= n:
        self.assertTrue(onp.all(norm(onp.eye(out[2].shape[1]) - onp.matmul(onp.conj(T(out[2])), out[2])) < 10))
      else:
        self.assertTrue(onp.all(norm(onp.eye(out[2].shape[0]) - onp.matmul(out[2], onp.conj(T(out[2])))) < 20))

    else:
      self.assertTrue(onp.allclose(onp.linalg.svd(a, compute_uv=False), onp.asarray(out), atol=1e-4, rtol=1e-4))

    self._CompileAndCheck(partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
                          args_maker, check_dtypes=True)
    if not full_matrices:
      svd = partial(np.linalg.svd, full_matrices=False)
      jtu.check_jvp(svd, partial(jvp, svd), (a,), atol=1e-1 if FLAGS.jax_enable_x64 else jtu.ATOL)
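
The reconstruction checks above use the identity A = U diag(S) V^H, written as (U * S) @ V^H so the diagonal matrix is never materialised, plus the unitarity of the singular-vector matrices. The same identities in plain NumPy for one arbitrary shape:

import numpy as np

rng = np.random.RandomState(0)
a = rng.randn(7, 4)

u, s, vh = np.linalg.svd(a, full_matrices=False)  # u: (7, 4), s: (4,), vh: (4, 4)
np.testing.assert_allclose((u * s) @ vh, a, atol=1e-12)       # reconstruction
np.testing.assert_allclose(u.T @ u, np.eye(4), atol=1e-12)    # orthonormal columns
np.testing.assert_allclose(vh @ vh.T, np.eye(4), atol=1e-12)  # orthonormal rows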

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_fullmatrices={}".format(
          jtu.format_shape_dtype_string(shape, dtype), full_matrices),
       "shape": shape, "dtype": dtype, "full_matrices": full_matrices,
       "rng": rng}
      for shape in [(1, 1), (3, 4), (2, 10, 5), (2, 200, 100)]
      for dtype in float_types()
      for full_matrices in [False, True]
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("cpu")
  def testQr(self, shape, dtype, full_matrices, rng):
    m, n = shape[-2:]

    if full_matrices:
      mode, k = "complete", m
    else:
      mode, k = "reduced", min(m, n)

    a = rng(shape, dtype)
    lq, lr = np.linalg.qr(a, mode=mode)

    # onp.linalg.qr doesn't support broadcasting. But it seems like an
    # inevitable extension so we support it in our version.
    nq = onp.zeros(shape[:-2] + (m, k), dtype)
    nr = onp.zeros(shape[:-2] + (k, n), dtype)
    for index in onp.ndindex(*shape[:-2]):
      nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

    max_rank = max(m, n)

    # Norm, adjusted for dimension and type.
    def norm(x):
      n = onp.linalg.norm(x, axis=(-2, -1))
      return n / (max_rank * onp.finfo(dtype).eps)

    def compare_orthogonal(q1, q2):
      # Q is unique up to sign, so normalize the sign first.
      sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
      phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
      q1 *= phases
      self.assertTrue(onp.all(norm(q1 - q2) < 30))

    # Check a ~= qr
    self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

    # Compare the first 'k' vectors of Q; the remainder form an arbitrary
    # orthonormal basis for the null space.
    compare_orthogonal(nq[..., :k], lq[..., :k])

    # Check that q is close to unitary.
    self.assertTrue(onp.all(norm(onp.eye(k) - onp.matmul(T(lq), lq)) < 5))

    if not full_matrices and m >= n:
        jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,))
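
compare_orthogonal has to normalise per-column signs because Q is only unique up to the sign (or complex phase) of each column. The underlying QR properties being exercised, in plain NumPy with an arbitrary shape:

import numpy as np

rng = np.random.RandomState(0)
a = rng.randn(5, 3)

q, r = np.linalg.qr(a, mode="reduced")  # q: (5, 3), r: (3, 3)
np.testing.assert_allclose(q @ r, a, atol=1e-12)            # a == q r
np.testing.assert_allclose(q.T @ q, np.eye(3), atol=1e-12)  # orthonormal columns
assert np.allclose(r, np.triu(r))                           # r is upper triangular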

  @jtu.skip_on_devices("gpu", "tpu")
  def testQrBatching(self):
    shape = (10, 4, 5)
    dtype = np.float32
    rng = jtu.rand_default()
    args = rng(shape, np.float32)
    qs, rs = vmap(jsp.linalg.qr)(args)
    self.assertTrue(onp.all(onp.linalg.norm(args - onp.matmul(qs, rs)) < 1e-3))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng": rng}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4,)),
          ((8, 8), (8, 4)),
      ]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testSolve(self, lhs_shape, rhs_shape, dtype, rng):
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(onp.linalg.solve, np.linalg.solve, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testInv(self, shape, dtype, rng):
    def args_maker():
      invertible = False
      while not invertible:
        a = rng(shape, dtype)
        try:
          onp.linalg.inv(a)
          invertible = True
        except onp.linalg.LinAlgError:
          pass
      return [a]

    self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)

  # Regression test for incorrect type for eigenvalues of a complex matrix.
  @jtu.skip_on_devices("gpu", "tpu")
  def testIssue669(self):
    def test(x):
      val, vec = np.linalg.eigh(x)
      return np.real(np.sum(val))

    grad_test_jc = jit(grad(jit(test)))
    # onp.complex was a deprecated alias for the builtin complex (complex128).
    xc = onp.eye(3, dtype=onp.complex128)
    self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)
Example #4
 def testArgminmax(self, op, shape, dtype, dim, bdims):
   rng = jtu.rand_default(self.rng())
   fun = lambda operand: op(operand, dim, np.int32)
   self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
Example #5
 def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
   fun = partial(lax.scatter_add, dimension_numbers=dnums)
   self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape],
                       [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()),
                       rtol={np.float16: 5e-3})
Example #6
 def test_sum_2d(self):
     self.check(jnp.sum, ['(m, n)'], '', dict(m=2, n=3), [(3, 4)],
                ['float_'], jtu.rand_default(self.rng()))
Example #7
 def test_where(self):
     # Requires mask(jit)
     raise SkipTest
     self.check(lambda x: jnp.where(x < 0, x, 0. * x), ['n'], 'n', {'n': 2},
                [(3, )], ['float_'], jtu.rand_default(self.rng()))
Example #8
 def testClassUnaryOp(self, dtype, shape, op):
     rng = jtu.rand_default(self.rng())
     args = (rng(shape, dtype), )
     class_op = lambda x: op(_DoubleDouble(x)).to_array()
     self.assertAllClose(op(*args), class_op(*args))
Example #9
 def testBinaryOp(self, dtype, shape, op):
     rng = jtu.rand_default(self.rng())
     op_doubled = doubledouble(op)
     args = rng(shape, dtype), rng(shape, dtype)
     self.assertAllClose(op(*args), op_doubled(*args))
Example #10
 def testEighGradPrecision(self):
     rng = jtu.rand_default()
     a = rng((3, 3), onp.float32)
     jtu.assert_dot_precision(lax.Precision.HIGHEST,
                              partial(jvp, np.linalg.eigh), (a, ), (a, ))
Example #11
class ScipyLinalgTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testLu(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng(shape, dtype)]
        x, = args_maker()
        p, l, u = jsp.linalg.lu(x)
        self.assertAllClose(x,
                            onp.matmul(p, onp.matmul(l, u)),
                            check_dtypes=True)
        self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

    def testLuOfSingularMatrix(self):
        x = np.array([[-1., 3. / 2], [2. / 3, -1.]], dtype=onp.float32)
        p, l, u = jsp.linalg.lu(x)
        self.assertAllClose(x,
                            onp.matmul(p, onp.matmul(l, u)),
                            check_dtypes=True)
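
scipy.linalg.lu returns a permutation matrix P and triangular factors L and U with A = P @ L @ U, which is exactly what both tests reassemble. A standalone SciPy check on the same singular matrix (the alias osp_linalg is just for this sketch):

import numpy as np
import scipy.linalg as osp_linalg

a = np.array([[-1.0, 1.5], [2.0 / 3.0, -1.0]])  # the singular matrix from the test above

p, l, u = osp_linalg.lu(a)
np.testing.assert_allclose(p @ l @ u, a, rtol=1e-10)
assert np.allclose(l, np.tril(l)) and np.allclose(u, np.triu(u))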

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 5), (10, 5), (10, 10), (6, 7, 7)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")  # TODO(phawkins): precision problems on TPU.
    def testLuGrad(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        a = rng(shape, dtype)
        lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
        jtu.check_grads(lu, (a, ), 2, atol=5e-2, rtol=1e-1)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(4, 5), (6, 5)] for dtype in [np.float32]
                            for rng in [jtu.rand_default()]))
    def testLuBatching(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args = [rng(shape, np.float32) for _ in range(10)]
        expected = list(osp.linalg.lu(x) for x in args)
        ps = onp.stack([out[0] for out in expected])
        ls = onp.stack([out[1] for out in expected])
        us = onp.stack([out[2] for out in expected])

        actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args))
        self.assertAllClose(ps, actual_ps, check_dtypes=True)
        self.assertAllClose(ls, actual_ls, check_dtypes=True)
        self.assertAllClose(us, actual_us, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [1, 4, 5, 200] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testLuFactor(self, n, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng((n, n), dtype)]

        x, = args_maker()
        lu, piv = jsp.linalg.lu_factor(x)
        l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype)
        u = onp.triu(lu)
        for i in range(n):
            x[[i, piv[i]], ] = x[[piv[i], i], ]
        self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3)
        self._CompileAndCheck(jsp.linalg.lu_factor,
                              args_maker,
                              check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}_trans={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), trans),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "trans":
            trans,
            "rng":
            rng
        } for lhs_shape, rhs_shape in [
            ((1, 1), (1, 1)),
            ((4, 4), (4, )),
            ((8, 8), (8, 4, 2)),
        ] for trans in [0, 1, 2] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testLuSolve(self, lhs_shape, rhs_shape, dtype, trans, rng):
        _skip_if_unsupported_type(dtype)
        osp_fun = lambda lu, piv, rhs: osp.linalg.lu_solve(
            (lu, piv), rhs, trans=trans)
        jsp_fun = lambda lu, piv, rhs: jsp.linalg.lu_solve(
            (lu, piv), rhs, trans=trans)

        def args_maker():
            a = rng(lhs_shape, dtype)
            lu, piv = osp.linalg.lu_factor(a)
            return [lu, piv, rng(rhs_shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}_sym_pos={}_lower={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), sym_pos,
                lower),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "sym_pos":
            sym_pos,
            "lower":
            lower,
            "rng":
            rng
        } for lhs_shape, rhs_shape in [
            ((1, 1), (1, 1)),
            ((4, 4), (4, )),
            ((8, 8), (8, 4)),
        ] for sym_pos, lower in [
            (False, False),
            (True, False),
            (True, True),
        ] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng):
        _skip_if_unsupported_type(dtype)
        if (sym_pos and onp.issubdtype(dtype, onp.complexfloating)
                and jtu.device_under_test() == "tpu"):
            raise unittest.SkipTest(
                "Complex Cholesky decomposition not implemented on TPU")
        osp_fun = lambda lhs, rhs: osp.linalg.solve(
            lhs, rhs, sym_pos=sym_pos, lower=lower)
        jsp_fun = lambda lhs, rhs: jsp.linalg.solve(
            lhs, rhs, sym_pos=sym_pos, lower=lower)

        def args_maker():
            a = rng(lhs_shape, dtype)
            if sym_pos:
                a = onp.matmul(a, onp.conj(T(a)))
                a = onp.tril(a) if lower else onp.triu(a)
            return [a, rng(rhs_shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), lower,
                transpose_a, unit_diagonal),
            "lower":
            lower,
            "transpose_a":
            transpose_a,
            "unit_diagonal":
            unit_diagonal,
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for lower in [False, True] for transpose_a in [False, True]
                            for unit_diagonal in [False, True]
                            for lhs_shape, rhs_shape in [
                                ((4, 4), (4, )),
                                ((4, 4), (4, 3)),
                                ((2, 8, 8), (2, 8, 10)),
                            ] for dtype in float_types
                            for rng in [jtu.rand_default()]))
    def testSolveTriangular(self, lower, transpose_a, unit_diagonal, lhs_shape,
                            rhs_shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        k = rng(lhs_shape, dtype)
        l = onp.linalg.cholesky(
            onp.matmul(k, T(k)) + lhs_shape[-1] * onp.eye(lhs_shape[-1]))
        l = l.astype(k.dtype)
        b = rng(rhs_shape, dtype)

        if unit_diagonal:
            a = onp.tril(l, -1) + onp.eye(lhs_shape[-1], dtype=dtype)
        else:
            a = l
        a = a if lower else T(a)

        inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
        if len(lhs_shape) == len(rhs_shape):
            onp_ans = onp.matmul(inv, b)
        else:
            onp_ans = onp.einsum("...ij,...j->...i", inv, b)

        # The standard scipy.linalg.solve_triangular doesn't support broadcasting.
        # But it seems like an inevitable extension so we support it.
        ans = jsp.linalg.solve_triangular(l if lower else T(l),
                                          b,
                                          trans=1 if transpose_a else 0,
                                          lower=lower,
                                          unit_diagonal=unit_diagonal)

        self.assertAllClose(onp_ans, ans, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".
                format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                       jtu.format_shape_dtype_string(rhs_shape, dtype), lower,
                       transpose_a, unit_diagonal),
                "lower":
                lower,
                "transpose_a":
                transpose_a,
                "unit_diagonal":
                unit_diagonal,
                "lhs_shape":
                lhs_shape,
                "rhs_shape":
                rhs_shape,
                "dtype":
                dtype,
                "rng":
                rng
            } for lower in [False, True] for unit_diagonal in [False, True]
            for dtype in float_types + complex_types for transpose_a in (
                [0, 1] if onp.issubdtype(dtype, np.floating) else [0, 1, 2])
            for lhs_shape, rhs_shape in [
                ((4, 4), (4, )),
                ((4, 4), (4, 3)),
                ((2, 8, 8), (2, 8, 10)),
            ] for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")  # TODO(phawkins): Test fails on TPU.
    def testSolveTriangularGrad(self, lower, transpose_a, unit_diagonal,
                                lhs_shape, rhs_shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        A = np.tril(
            rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
        A = A if lower else T(A)
        B = rng(rhs_shape, dtype)
        f = partial(jsp.linalg.solve_triangular,
                    lower=lower,
                    trans=transpose_a,
                    unit_diagonal=unit_diagonal)
        jtu.check_grads(f, (A, B), 2, rtol=2e-2, eps=1e-3)
Example #12
 def testIfftshift(self, shape, dtype, axes):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: (rng(shape, dtype), )
     jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes)
     np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
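
jnp.fft.ifftshift mirrors numpy.fft.ifftshift: it undoes fftshift by rolling the zero-frequency bin back to index 0 along the selected axes. A small sketch comparing the two (the input and the axes argument are arbitrary):

import numpy as np
import jax.numpy as jnp

x = np.arange(12, dtype=np.float32).reshape(3, 4)

np.testing.assert_allclose(jnp.fft.ifftshift(x, axes=1), np.fft.ifftshift(x, axes=1))
# ifftshift exactly inverts fftshift, for even and odd lengths alike.
np.testing.assert_allclose(np.fft.ifftshift(np.fft.fftshift(x)), x)
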
Example #13
class IndexedUpdateTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng":
                rng,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in STATIC_INDEXING_TESTS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else all_dtypes)
            for rng in [jtu.rand_default()]))
    def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, rng,
                           indexer, op):
        args_maker = lambda: [
            rng(shape, dtype),
            rng(update_shape, update_dtype)
        ]
        onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
        jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
        self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
        self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)
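
UpdateOps here wraps the old jax.ops.index_update / index_add helpers; in current JAX the same functional updates are written with the .at[] property of arrays. A minimal sketch of the NumPy-vs-JAX comparison these tests perform, using that modern spelling (the indexer and shapes are arbitrary):

import numpy as np
import jax.numpy as jnp

x = np.ones((3, 4), dtype=np.float32)
y = np.arange(4, dtype=np.float32)
idx = 1  # static indexer: row 1

# NumPy reference: in-place mutation of copies.
expected_set = x.copy()
expected_set[idx] = y
expected_add = x.copy()
expected_add[idx] += y

# JAX: out-of-place functional updates.
np.testing.assert_allclose(jnp.asarray(x).at[idx].set(y), expected_set)
np.testing.assert_allclose(jnp.asarray(x).at[idx].add(y), expected_add)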

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng":
                rng,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else all_dtypes)
            for rng in [jtu.rand_default()]))
    def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                             rng, indexer, op):
        args_maker = lambda: [
            rng(shape, dtype),
            rng(update_shape, update_dtype)
        ]
        onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
        jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
        self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
        self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng":
                rng,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else all_dtypes)
            for rng in [jtu.rand_default()]))
    def testMixedAdvancedIndexing(self, shape, dtype, update_shape,
                                  update_dtype, rng, indexer, op):
        args_maker = lambda: [
            rng(shape, dtype),
            rng(update_shape, update_dtype)
        ]
        onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
        jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
        self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
        self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng":
                rng,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in STATIC_INDEXING_TESTS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in float_dtypes
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else float_dtypes)
            for rng in [jtu.rand_default()]))
    def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                                rng, indexer, op):
        jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add
        jax_fn = lambda x, y: jax_op(x, indexer, y)
        x = rng(shape, dtype)
        y = rng(update_shape, update_dtype)
        check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)

    def testSegmentSumBehavior(self):
        # testAdvancedIndexing compares against NumPy, and as a result doesn't check
        # repeated indices. This test is just a simple manual check, based on
        # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
        data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
        segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

        ans = ops.index_add(onp.zeros(onp.max(segment_ids) + 1), segment_ids,
                            data)
        expected = onp.array([13, 2, 7, 4])
        self.assertAllClose(ans, expected, check_dtypes=False)
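
    # Sketch (not part of the original tests): why repeated indices need the
    # manual check above. Plain NumPy fancy-index assignment buffers duplicate
    # indices, so only the last contribution per index survives, whereas
    # ops.index_add (like onp.add.at) accumulates all of them. Uses the same
    # `ops`/`onp` aliases as the surrounding tests.
    def exampleRepeatedIndexSemantics(self):
        data = onp.array([5., 1., 7.])
        idx = onp.array([0, 0, 1])
        buffered = onp.zeros(2)
        buffered[idx] += data  # -> [1., 7.]: duplicate index 0 keeps last write
        accumulated = onp.zeros(2)
        onp.add.at(accumulated, idx, data)  # -> [6., 7.]: duplicates are summed
        ans = ops.index_add(onp.zeros(2), idx, data)
        self.assertAllClose(buffered, onp.array([1., 7.]), check_dtypes=False)
        self.assertAllClose(ans, accumulated, check_dtypes=False)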

    def testSegmentSum(self):
        data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
        segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

        # test with explicit num_segments
        ans = ops.segment_sum(data, segment_ids, num_segments=4)
        expected = onp.array([13, 2, 7, 4])
        self.assertAllClose(ans, expected, check_dtypes=False)

        # test without explicit num_segments
        ans = ops.segment_sum(data, segment_ids)
        expected = onp.array([13, 2, 7, 4])
        self.assertAllClose(ans, expected, check_dtypes=False)
Example #14
0
class IndexingTest(jtu.JaxTestCase):
    """Tests for Numpy indexing translation rules."""
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name": "{}_inshape={}_indexer={}".format(
                 name, jtu.format_shape_dtype_string(shape, dtype), indexer),
             "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
            for name, index_specs in STATIC_INDEXING_TESTS
            for shape, indexer in index_specs
            for dtype in all_dtypes
            for rng in [jtu.rand_default()]))
    def testStaticIndexing(self, shape, dtype, rng, indexer):
        args_maker = lambda: [rng(shape, dtype)]
        fun = lambda x: x[indexer]
        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        {"testcase_name": "{}_inshape={}_indexer={}".format(
             name, jtu.format_shape_dtype_string(shape, dtype), indexer),
         "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
        for name, index_specs in STATIC_INDEXING_GRAD_TESTS
        for shape, indexer in index_specs
        for dtype in float_dtypes
        for rng in [jtu.rand_default()])
    def testStaticIndexingGrads(self, shape, dtype, rng, indexer):
        tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
        arg = rng(shape, dtype)
        fun = lambda x: x[indexer]**2
        check_grads(fun, (arg, ), 2, tol, tol, tol)

    def _ReplaceSlicesWithTuples(self, idx):
        """Helper method to replace slices with tuples for dynamic indexing args."""
        if isinstance(idx, slice):
            triple = idx.start, idx.stop, idx.step
            isnone = [i for i, elt in enumerate(triple) if elt is None]
            zeros = itertools.repeat(0)
            nones = itertools.repeat(None)
            out = lax.subvals(triple, zip(isnone, zeros))
            return out, lambda out: slice(
                *lax.subvals(out, zip(isnone, nones)))
        elif isinstance(idx, (tuple, list)) and idx:
            t = type(idx)
            elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx))
            return elts, lambda elts: t(
                (pack(i) for pack, i in zip(packs, elts)))
        else:
            return idx, lambda x: x
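
    # Illustrative usage (not from the original suite) of the helper above:
    # slice bounds become a tuple of ints that can be passed as traced
    # arguments, and the returned packer restores the original slice, with
    # its None fields, inside the traced function.
    def exampleReplaceSlicesWithTuples(self):
        unpacked, pack = self._ReplaceSlicesWithTuples(slice(1, None, None))
        self.assertEqual(unpacked, (1, 0, 0))
        self.assertEqual(pack(unpacked), slice(1, None, None))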

    @parameterized.named_parameters(
        {"testcase_name": "{}_inshape={}_indexer={}".format(
             name, jtu.format_shape_dtype_string(shape, dtype), indexer),
         "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
        for name, index_specs in [
        ("OneSliceIndex", [
            IndexSpec(shape=(5, ), indexer=slice(1, 3)),
            IndexSpec(shape=(5, 4), indexer=slice(1, 3))
        ]),
        ("TwoSliceIndices", [
            IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))
        ]),
        ("NonUnitStrides", [
            IndexSpec(shape=(3, ), indexer=slice(None, None, -1)),
            IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
            IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
        ]),
        ("OnlyStartOrStopDynamic", [
            IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
        ]),
        ]
        for shape, indexer in index_specs
        for dtype in all_dtypes
        for rng in [jtu.rand_default()])
    def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng, indexer):
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        @api.jit
        def fun(x, unpacked_indexer):
            indexer = pack_indexer(unpacked_indexer)
            return x[indexer]

        args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
        self.assertRaises(IndexError, lambda: fun(*args_maker()))

    @parameterized.named_parameters(
        {"testcase_name": "{}_inshape={}_indexer={}".format(
             name, jtu.format_shape_dtype_string(shape, dtype), indexer),
         "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
        for name, index_specs in [
        ("OneIntIndex", [
            IndexSpec(shape=(3, ), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3, ), indexer=-1),
            IndexSpec(shape=(3, ), indexer=-2)
        ]),
        ("TwoIntIndices", [
            IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))
        ]),
        ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
        ]
        for shape, indexer in index_specs
        for dtype in all_dtypes
        for rng in [jtu.rand_default()])
    def testDynamicIndexingWithIntegers(self, shape, dtype, rng, indexer):
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        def fun(x, unpacked_indexer):
            indexer = pack_indexer(unpacked_indexer)
            return x[indexer]

        args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        {"testcase_name": "{}_inshape={}_indexer={}".format(
             name, jtu.format_shape_dtype_string(shape, dtype), indexer),
         "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
        for name, index_specs in [
        ("OneIntIndex", [
            IndexSpec(shape=(3, ), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3, ), indexer=-1),
            IndexSpec(shape=(3, ), indexer=-2),
        ]),
        ("TwoIntIndices", [
            IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
        ]),
        ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
        ]
        for shape, indexer in index_specs
        for dtype in float_dtypes
        for rng in [jtu.rand_default()])
    def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng, indexer):
        tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        @api.jit
        def fun(unpacked_indexer, x):
            indexer = pack_indexer(unpacked_indexer)
            return x[indexer]

        arr = rng(shape, dtype)
        check_grads(partial(fun, unpacked_indexer), (arr, ), 2, tol, tol, tol)

    @parameterized.named_parameters(
        {"testcase_name": "{}_inshape={}_indexer={}".format(
             name, jtu.format_shape_dtype_string(shape, dtype), indexer),
         "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
        for name, index_specs in ADVANCED_INDEXING_TESTS
        for shape, indexer in index_specs
        for dtype in all_dtypes
        for rng in [jtu.rand_default()])
    def testAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
        args_maker = lambda: [rng(shape, dtype), indexer]
        fun = lambda x, idx: x[idx]
        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        {"testcase_name": "{}_inshape={}_indexer={}".format(
             name, jtu.format_shape_dtype_string(shape, dtype), indexer),
         "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
        for name, index_specs in [
        ("One1DIntArrayIndex", [
            IndexSpec(shape=(3, ), indexer=onp.array([0, 1])),
            IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])),
            IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])),
            IndexSpec(shape=(3, ), indexer=onp.array([-1, 1])),
            IndexSpec(shape=(3, ), indexer=onp.array([-2, -1])),
        ]),
        ("One2DIntArrayIndex", [
            IndexSpec(shape=(3, ), indexer=onp.array([[0, 0]])),
            IndexSpec(shape=(3, 3),
                      indexer=onp.array([[1, 2, 1], [0, 1, -1]])),
            IndexSpec(shape=(3, 4, 5),
                      indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]])),
        ]),
        ("Two1DIntArrayIndicesNoBroadcasting", [
            IndexSpec(shape=(3, 3),
                      indexer=[onp.array([0, 1]),
                               onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[onp.array([0, 2, 0, 1]),
                               onp.array([-1, 0, -1, 2])]),
        ]),
        ("Two1DIntArrayIndicesWithBroadcasting", [
            IndexSpec(shape=(3, 3),
                      indexer=[onp.array([[0, 1]]),
                               onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[
                          onp.array([[0, 2, 0, 1]]),
                          onp.array([-1, 0, -1, 2])
                      ]),
        ]),
        ("ListOfPythonInts", [
            IndexSpec(shape=(3, ), indexer=[0, 1, 0]),
            IndexSpec(shape=(3, 4, 5), indexer=[0, -1]),
        ]),
        ("ListOfListsOfPythonInts", [
            IndexSpec(shape=(3, 4, 5), indexer=[[0, 1]]),
            IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], [[2, 3, 0, 3]]]),
        ]),
        ("ListOfPythonIntsAndIntArrays", [
            IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[0, 1, onp.array([[2, 3, 0, 3]])]),
        ]),
        ("ListOfListsOfPythonIntsAndIntArrays", [
            IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[[[0], [-1]],
                               onp.array([[2, 3, 0, 3]])]),
        ]),
        ]
        for shape, indexer in index_specs
        for dtype in float_dtypes
        for rng in [jtu.rand_default()])
    def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng, indexer):
        tol = 1e-2 if onp.finfo(dtype).bits == 32 else None
        arg = rng(shape, dtype)
        fun = lambda x: x[indexer]**2
        check_grads(fun, (arg, ), 2, tol, tol, tol)

    @parameterized.named_parameters(
        {"testcase_name": "{}_inshape={}_indexer={}".format(
             name, jtu.format_shape_dtype_string(shape, dtype), indexer),
         "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer}
        for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
        for shape, indexer in index_specs
        for dtype in all_dtypes
        for rng in [jtu.rand_default()])
    def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng, indexer):
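        # Pass only the array parts of the indexer as traced arguments; the
        # non-array parts (ints, slices, etc.) are swapped for () dummies here
        # and spliced back in statically inside `fun` via lax.subvals.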
        indexer_with_dummies = [
            e if isinstance(e, onp.ndarray) else () for e in indexer
        ]
        substitutes = [(i, e) for i, e in enumerate(indexer)
                       if not isinstance(e, onp.ndarray)]
        args_maker = lambda: [rng(shape, dtype), indexer_with_dummies]

        def fun(x, indexer_with_dummies):
            idx = type(indexer)(lax.subvals(indexer_with_dummies, substitutes))
            return x[idx]

        self._CompileAndCheck(fun, args_maker, check_dtypes=True)

    def testAdvancedIndexingManually(self):
        x = onp.random.RandomState(0).randn(3, 4, 5)
        index_array = onp.array([0, 2, -1, 0])

        op = lambda x, index_array: x[..., index_array, :]
        cop = api.jit(op)

        a1 = op(x, index_array)
        a2 = cop(x, index_array)

        self.assertAllClose(a1, a2, check_dtypes=True)

        op = lambda x, index_array: x[..., index_array, :, index_array, None]
        cop = api.jit(op)

        a1 = op(x, index_array)
        a2 = cop(x, index_array)

        self.assertAllClose(a1, a2, check_dtypes=True)

        op = lambda x, index_array: x[index_array, ..., index_array[:, None],
                                      None]
        cop = api.jit(op)

        a1 = op(x, index_array)
        a2 = cop(x, index_array)

        self.assertAllClose(a1, a2, check_dtypes=True)

    def testUnpacking(self):
        def foo(x):
            a, b, c = x
            return a + b + c

        cfoo = api.jit(foo)

        a1 = foo(onp.arange(3))
        a2 = cfoo(onp.arange(3))

        self.assertAllClose(a1, a2, check_dtypes=True)

    def testBooleanIndexingArray1D(self):
        idx = onp.array([True, True, False])
        x = api.device_put(onp.arange(3))
        ans = x[idx]
        expected = onp.arange(3)[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingList1D(self):
        idx = [True, True, False]
        x = api.device_put(onp.arange(3))
        ans = x[idx]
        expected = onp.arange(3)[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingArray2DBroadcast(self):
        idx = onp.array([True, True, False, True])
        x = onp.arange(8).reshape(4, 2)
        ans = api.device_put(x)[idx]
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingList2DBroadcast(self):
        idx = [True, True, False, True]
        x = onp.arange(8).reshape(4, 2)
        ans = api.device_put(x)[idx]
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingArray2D(self):
        idx = onp.array([[True, False], [False, True], [False, False],
                         [True, True]])
        x = onp.arange(8).reshape(4, 2)
        ans = api.device_put(x)[idx]
        expected = x[idx]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testBooleanIndexingDynamicShapeError(self):
        x = onp.zeros(3)
        i = onp.array([True, True, False])
        self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, i))
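
    # Sketch (not part of the original suite): the failure above arises
    # because a boolean mask produces a data-dependent output shape, which jit
    # cannot stage out. Converting the mask to integer indices before tracing
    # keeps the shape static, so the equivalent gather works under jit.
    def exampleBooleanIndexingStaticWorkaround(self):
        x = onp.arange(3.)
        i = onp.array([True, True, False])
        static_idx = onp.where(i)[0]  # shape is known before tracing
        ans = api.jit(lambda x, idx: x[idx])(x, static_idx)
        expected = x[i]
        self.assertAllClose(ans, expected, check_dtypes=False)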

    def testIssue187(self):
        x = lnp.ones((5, 5))
        x[[0, 2, 4], [0, 2, 4]]  # doesn't crash

        x = onp.arange(25).reshape((5, 5))
        ans = api.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x)
        expected = x[[0, 2, 4], [0, 2, 4]]
        self.assertAllClose(ans, expected, check_dtypes=False)

    def testJVPOfGradOfIndexing(self):
        # Should return a value, even though we didn't pass a symbolic zero as the
        # index tangent.
        x = lnp.ones((3, 4), lnp.float32)
        i = lnp.ones((3, ), lnp.int32)
        f = lambda x, i: lnp.sum(x[i])
        primals, tangents = api.jvp(api.grad(f), (x, i),
                                    (x, onp.zeros_like(i)))
        expected = onp.broadcast_to(
            onp.array([0, 3, 0], dtype=onp.float32)[:, None], (3, 4))
        self.assertAllClose(expected, primals, check_dtypes=True)
        self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True)
Example #15
0
 def test_lax_slice(self):
     self.check(lambda x: lax.slice(x, (1, ), (x.shape[0], )), ['n'],
                'n+-1', {'n': 2}, [(3, )], ['float_'],
                jtu.rand_default(self.rng()))
Example #16
0
class NumpyLinalgTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testCholesky(self, shape, dtype, rng):
    def args_maker():
      a = rng(shape, dtype)
      return [onp.matmul(a, np.conj(T(a)))]

    self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.cholesky, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng": rng}
      for n in [0, 4, 5, 50]
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testDet(self, n, dtype, rng):
    if not hasattr(lapack, "jax_getrf"):
      self.skipTest("No LU implementation available")
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.det, np.linalg.det, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng": rng}
      for n in [0, 4, 10, 200]
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testSlogdet(self, n, dtype, rng):
    if not hasattr(lapack, "jax_getrf"):
      self.skipTest("No LU implementation available")
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}_lower={}".format(
           jtu.format_shape_dtype_string((n,n), dtype), lower),
       "n": n, "dtype": dtype, "lower": lower, "rng": rng}
      for n in [0, 4, 5, 50]
      for dtype in float_types() | complex_types()
      for lower in [False, True]
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEigh(self, n, dtype, lower, rng):
    if not hasattr(lapack, "jax_syevd"):
      self.skipTest("No symmetric eigendecomposition implementation available")
    args_maker = lambda: [rng((n, n), dtype)]

    uplo = "L" if lower else "U"

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / ((n + 1) * onp.finfo(dtype).eps)

    a, = args_maker()
    a = (a + onp.conj(a.T)) / 2
    w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a), UPLO=uplo)

    self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
    self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30)

    self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker,
                          check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format(
          jtu.format_shape_dtype_string((m, n), dtype), full_matrices, compute_uv),
       "m": m, "n": n, "dtype": dtype, "full_matrices": full_matrices,
       "compute_uv": compute_uv, "rng": rng}
      for m in [2, 7, 29, 53]
      for n in [2, 7, 29, 53]
      for dtype in float_types() | complex_types()
      for full_matrices in [False, True]
      for compute_uv in [False, True]
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
    if not hasattr(lapack, "jax_gesdd"):
      self.skipTest("No singular value decomposition implementation available")

    args_maker = lambda: [rng((m, n), dtype)]

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / (max(m, n) * onp.finfo(dtype).eps)

    a, = args_maker()
    out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

    if compute_uv:
      # Check the reconstructed matrices
      if full_matrices:
        k = min(m, n)
        if m < n:
          self.assertTrue(onp.all(
              norm(a - onp.matmul(out[1] * out[0], out[2][:k, :])) < 50))
        else:
          self.assertTrue(onp.all(
              norm(a - onp.matmul(out[1] * out[0][:, :k], out[2])) < 50))
      else:
        self.assertTrue(onp.all(
            norm(a - onp.matmul(out[1] * out[0], out[2])) < 50))

      # Check the unitary properties of the singular vector matrices.
      self.assertTrue(onp.all(norm(
          onp.eye(out[0].shape[1]) -
          onp.matmul(onp.conj(T(out[0])), out[0])) < 10))
      if m >= n:
        self.assertTrue(onp.all(norm(
            onp.eye(out[2].shape[1]) -
            onp.matmul(onp.conj(T(out[2])), out[2])) < 10))
      else:
        self.assertTrue(onp.all(norm(
            onp.eye(out[2].shape[0]) -
            onp.matmul(out[2], onp.conj(T(out[2])))) < 20))

    else:
      self.assertTrue(onp.allclose(
          onp.linalg.svd(a, compute_uv=False), onp.asarray(out),
          atol=1e-4, rtol=1e-4))

    self._CompileAndCheck(partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
                          args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_fullmatrices={}".format(
          jtu.format_shape_dtype_string(shape, dtype), full_matrices),
       "shape": shape, "dtype": dtype, "full_matrices": full_matrices,
       "rng": rng}
      for shape in [(1, 1), (3, 4), (2, 10, 5), (2, 200, 100)]
      for dtype in float_types()
      for full_matrices in [False, True]
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("cpu")
  def testQr(self, shape, dtype, full_matrices, rng):
    m, n = shape[-2:]

    if full_matrices:
      mode, k = "complete", m
    else:
      mode, k = "reduced", min(m, n)

    a = rng(shape, dtype)
    lq, lr = np.linalg.qr(a, mode=mode)

    # onp.linalg.qr doesn't support broadcasting. But it seems like an
    # inevitable extension so we support it in our version.
    nq = onp.zeros(shape[:-2] + (m, k), dtype)
    nr = onp.zeros(shape[:-2] + (k, n), dtype)
    for index in onp.ndindex(*shape[:-2]):
      nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

    max_rank = max(m, n)

    # Norm, adjusted for dimension and type.
    def norm(x):
      n = onp.linalg.norm(x, axis=(-2, -1))
      return n / (max_rank * onp.finfo(dtype).eps)

    def compare_orthogonal(q1, q2):
      # Q is unique up to sign, so normalize the sign first.
      sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
      phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
      q1 *= phases
      self.assertTrue(onp.all(norm(q1 - q2) < 30))

    # Check a ~= qr
    self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

    # Compare the first 'k' vectors of Q; the remainder form an arbitrary
    # orthonormal basis for the null space.
    compare_orthogonal(nq[..., :k], lq[..., :k])

    # Check that q is close to unitary.
    self.assertTrue(onp.all(norm(onp.eye(k) - onp.matmul(T(lq), lq)) < 5))

    if not full_matrices and m >= n:
        jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng": rng}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4,)),
          ((8, 8), (8, 4)),
      ]
      for dtype in float_types() | complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testSolve(self, lhs_shape, rhs_shape, dtype, rng):
    if not hasattr(lapack, "jax_getrf"):
      self.skipTest("No LU implementation available")
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(onp.linalg.solve, np.linalg.solve, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True)


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testInv(self, shape, dtype, rng):
    def args_maker():
      invertible = False
      while not invertible:
        a = rng(shape, dtype)
        try:
          onp.linalg.inv(a)
          invertible = True
        except onp.linalg.LinAlgError:
          pass
      return [a]

    self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
Example #17
0
 def test_transpose(self):
     self.check(lambda x: lax.transpose(x, (1, 0, 2)), ['(a, b, c)'],
                'b, a, c', dict(a=2, b=3, c=4), [(3, 4, 5)], ['float_'],
                jtu.rand_default(self.rng()))
Example #18
0
def op_record(name, nargs, dtypes, rng, test_grad, test_name=None):
  test_name = test_name or name
  return OpRecord(name, nargs, dtypes, rng, test_grad, test_name)

JAX_SPECIAL_FUNCTION_RECORDS = [
    # TODO: digamma has no JVP implemented.
    op_record("digamma", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("erf", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("expit", 1, float_dtypes, jtu.rand_small_positive(), True),
    # TODO: gammaln has slightly high error.
    op_record("gammaln", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("logit", 1, float_dtypes, jtu.rand_small_positive(), False),
    op_record("log_ndtr", 1, float_dtypes, jtu.rand_default(), True),
    op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0.05, 0.95), True),
    op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True),
    # TODO(phawkins): gradient of entr yields NaNs.
    op_record("entr", 1, float_dtypes, jtu.rand_default(), False),
]

CombosWithReplacement = itertools.combinations_with_replacement


class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
Example #19
0
 def test_expit(self):
     raise SkipTest("custom_jvp doesn't work with masking yet")
     self.check(expit, ['n'], 'n', dict(n=3), [(4, )], ['float_'],
                jtu.rand_default(self.rng()))
Example #20
0
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": jtu.rand_default(), "shape": shape, "dtype": dtype,
       "axis": axis, "keepdims": keepdims}
      for shape in all_shapes for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
         "test_autodiff": rec.test_autodiff,
         "scipy_op": getattr(osp_special, rec.name),
         "lax_op": getattr(lsp_special, rec.name)}
        for shapes in CombosWithReplacement(all_shapes, rec.nargs)
        for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                          test_autodiff):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

    if test_autodiff:
      jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_d={}".format(
          jtu.format_shape_dtype_string(shape, dtype), d),
       "rng": jtu.rand_positive(), "shape": shape, "dtype": dtype, "d": d}
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, rng, shape, dtype, d):
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue980(self):
    x = onp.full((4,), -1e20, dtype=onp.float32)
    self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
                        lsp_special.expit(x), check_dtypes=True)
Example #21
0
 def test_reduce(self, operator):
     self.check(operator, ['(m, n)'], '', {
         'm': 3,
         'n': 4
     }, [(4, 5)], ['float_'], jtu.rand_default(self.rng()))
Example #22
0
 def testSparseAttrAccess(self, attr):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [make_sparse_array(rng, (10, ), jnp.float32)]
     f = lambda x: getattr(x, attr)
     self._CompileAndCheck(f, args_maker)
Example #23
0
 def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
   fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
   self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
                       jtu.rand_default(self.rng()))
Example #24
0
 def test_where(self):
     self.check(lambda x: jnp.where(x < 0, x, 0. * x), ['n'], 'n', {'n': 2},
                [(3, )], ['float_'], jtu.rand_default(self.rng()))
Example #25
0
int_dtypes = [onp.int32, onp.int64]
bool_dtypes = [onp.bool_]
default_dtypes = float_dtypes + int_dtypes
numeric_dtypes = float_dtypes + complex_dtypes + int_dtypes


OpRecord = collections.namedtuple("OpRecord", ["name", "nargs", "dtypes", "rng",
                                               "diff_modes", "test_name"])


def op_record(name, nargs, dtypes, rng, diff_modes, test_name=None):
  test_name = test_name or name
  return OpRecord(name, nargs, dtypes, rng, diff_modes, test_name)

JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("add", 2, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("bitwise_and", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("bitwise_not", 1, default_dtypes, jtu.rand_bool(), []),
    op_record("bitwise_or", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("bitwise_xor", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("ceil", 1, float_dtypes, jtu.rand_default(), []),
    op_record("conj", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("conjugate", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("equal", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("exp", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("floor", 1, float_dtypes, jtu.rand_default(), []),
    op_record("greater", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("greater_equal", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("less", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("less_equal", 2, default_dtypes, jtu.rand_some_equal(), []),
Example #26
0
 def test_split(self):
     self.check(lambda x: jnp.split(x, 2), ['2*n'], ['n', 'n'], dict(n=4),
                [(8, )], ['float_'], jtu.rand_default(self.rng()))
     self.check(lambda x: jnp.split(x, [10]), ['n'], ['10', 'n+-10'],
                dict(n=12), [(12, )], ['float_'],
                jtu.rand_default(self.rng()))
Example #27
0
class ScipyLinalgTest(jtu.JaxTestCase):

  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testLu(self, shape, dtype, rng):
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(jsp.linalg.lu, osp.linalg.lu, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 5), (10, 5), (10, 10)]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testLuGrad(self, shape, dtype, rng):
    a = rng(shape, dtype)

    jtu.check_grads(jsp.linalg.lu, (a,), 2, rtol=1e-1)

  @jtu.skip_on_devices("gpu", "tpu")
  def testLuBatching(self):
    shape = (4, 5)
    dtype = np.float32
    rng = jtu.rand_default()
    args = [rng(shape, np.float32) for _ in range(10)]
    expected = list(osp.linalg.lu(x) for x in args)
    ps = onp.stack([out[0] for out in expected])
    ls = onp.stack([out[1] for out in expected])
    us = onp.stack([out[2] for out in expected])

    actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args))
    self.assertAllClose(ps, actual_ps, check_dtypes=True)
    self.assertAllClose(ls, actual_ls, check_dtypes=True)
    self.assertAllClose(us, actual_us, check_dtypes=True)

  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)),
       "n": n, "dtype": dtype, "rng": rng}
      for n in [1, 4, 5, 200]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testLuFactor(self, n, dtype, rng):
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor,
                            args_maker, check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_sym_pos={}_lower={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           sym_pos, lower),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "sym_pos": sym_pos, "lower": lower, "rng": rng}
      for lhs_shape, rhs_shape in [
          ((1, 1), (1, 1)),
          ((4, 4), (4,)),
          ((8, 8), (8, 4)),
      ]
      for sym_pos, lower in [
        (False, False),
        (True, False),
        (True, True),
      ]
      for dtype in float_types() + complex_types()
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an LU implementation for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng):
    osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)
    jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower)

    def args_maker():
      a = rng(lhs_shape, dtype)
      if sym_pos:
        a = onp.matmul(a, onp.conj(T(a)))
        a = onp.tril(a) if lower else onp.triu(a)
      return [a, rng(rhs_shape, dtype)]

    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_lower={}_transposea={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           lower, transpose_a),
       "lower": lower, "transpose_a": transpose_a,
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng": rng}
      for lower, transpose_a in itertools.product([False, True], repeat=2)
      for lhs_shape, rhs_shape in [
          ((4, 4), (4,)),
          ((4, 4), (4, 3)),
          ((2, 8, 8), (2, 8, 10)),
      ]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testSolveTriangular(self, lower, transpose_a, lhs_shape, rhs_shape, dtype,
                          rng):
    k = rng(lhs_shape, dtype)
    l = onp.linalg.cholesky(onp.matmul(k, T(k))
                            + lhs_shape[-1] * onp.eye(lhs_shape[-1]))
    l = l.astype(k.dtype)
    b = rng(rhs_shape, dtype)

    a = l if lower else T(l)
    inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
    if len(lhs_shape) == len(rhs_shape):
      onp_ans = onp.matmul(inv, b)
    else:
      onp_ans = onp.einsum("...ij,...j->...i", inv, b)

    # The standard scipy.linalg.solve_triangular doesn't support broadcasting.
    # But it seems like an inevitable extension so we support it.
    ans = jsp.linalg.solve_triangular(
        l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower)

    self.assertAllClose(onp_ans, ans, check_dtypes=True)
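
  # Sketch (not part of the original suite): the broadcasting behavior checked
  # in testSolveTriangular should coincide with vmapping the 2-D triangular
  # solve over the leading batch axis. Shapes and the rng below are purely
  # illustrative.
  def exampleSolveTriangularBatching(self):
    rng = jtu.rand_default()
    a = rng((2, 4, 4), onp.float32)
    l = onp.linalg.cholesky(onp.matmul(a, onp.swapaxes(a, -1, -2))
                            + 4 * onp.eye(4, dtype=onp.float32))
    b = rng((2, 4, 3), onp.float32)
    batched = jsp.linalg.solve_triangular(l, b, lower=True)
    mapped = vmap(partial(jsp.linalg.solve_triangular, lower=True))(l, b)
    self.assertAllClose(batched, mapped, check_dtypes=True,
                        atol=1e-4, rtol=1e-4)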

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs={}_rhs={}_lower={}_transposea={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           lower, transpose_a),
       "lower": lower, "transpose_a": transpose_a,
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng": rng}
      for lower, transpose_a in itertools.product([False, True], repeat=2)
      for lhs_shape, rhs_shape in [
          ((4, 4), (4,)),
          ((4, 4), (4, 3)),
          ((2, 8, 8), (2, 8, 10)),
      ]
      for dtype in float_types()
      for rng in [jtu.rand_default()]))
  def testSolveTriangularGrad(self, lower, transpose_a, lhs_shape,
                                     rhs_shape, dtype, rng):
    # TODO(frostig): change ensemble to support a bigger rtol
    self.skipTest("rtol does not cover all devices and precision modes")
    A = np.tril(rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
    A = A if lower else T(A)
    B = rng(rhs_shape, dtype)
    f = partial(jsp.linalg.solve_triangular, lower=lower,
                trans=1 if transpose_a else 0)
    jtu.check_grads(f, (A, B), 2, rtol=1e-3)
Example #28
0
 def test_mean(self):
     # TODO Shapecheck fails - shape_as_value can't deal with abstract eval yet
     raise SkipTest
     self.check(lambda x: jnp.sum(x) / shape_as_value(x.shape)[0], ['n'],
                '', {'n': 3}, [(4, )], ['float_'],
                jtu.rand_default(self.rng()))
Example #29
0
class BatchingTest(jtu.JaxTestCase):

  def testConstantFunction(self):
    ans = vmap(lambda x: 3)(onp.ones(4))
    expected = 3 * onp.ones(4)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testNestedBatchingMatMat(self):
    matvec = vmap(np.vdot, in_axes=(0, None))
    matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)

    R = onp.random.RandomState(0).randn
    A = R(4, 3)
    B = R(3, 2)

    ans = matmat(A, B)
    expected = onp.dot(A, B)
    self.assertAllClose(ans, expected, check_dtypes=False)

    # this is a crude check that we only call a single dot
    def pv_like(x):
      aval = ShapedArray(onp.shape(x), onp.result_type(x))
      return pe.PartialVal((aval, unit))

    def make_jaxpr(fun, example_args):
      jaxpr, _, _, _ = trace_to_jaxpr(fun, map(pv_like, example_args))
      return jaxpr

    jaxpr = make_jaxpr(matmat, (A, B))
    self.assertEqual(len(jaxpr.eqns), 1)

  def testPerExampleGradients(self):
    def predict(params, inputs):
      for W, b in params:
        outputs = np.dot(W, inputs) + b
        inputs = np.tanh(outputs)
      return outputs

    def loss(params, data):
      inputs, targets = data
      predictions = predict(params, inputs)
      return np.sum((predictions - targets)**2)

    batch_size = 5
    layer_sizes = [3, 2, 4]

    R = onp.random.RandomState(0).randn
    params = [(R(m, n), R(m))
              for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]

    input_vec = R(3)
    target_vec = R(4)
    datum = (input_vec, target_vec)

    input_batch = R(5, 3)
    target_batch = R(5, 4)
    batch = (input_batch, target_batch)

    ans = vmap(partial(grad(loss), params))(batch)

    for ans_pair, param_pair in zip(ans, params):
      dW, db = ans_pair
      W, b = param_pair

      self.assertEqual(dW.shape, (batch_size,) + W.shape)
      self.assertEqual(db.shape, (batch_size,) + b.shape)

  def testJacobians(self):
    def jacbwd(f, x):
      y, pullback = vjp(f, x)
      std_basis = onp.eye(onp.size(y)).reshape((-1,) + onp.shape(y))
      jac_flat, = vmap(pullback, out_axes=onp.ndim(y))(std_basis)
      return jac_flat.reshape(onp.shape(y) + onp.shape(x))

    def jacfwd(f, x):
      pushfwd = lambda v: jvp(f, (x,), (v,))
      std_basis = onp.eye(onp.size(x)).reshape((-1,) + onp.shape(x))
      y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
      return jac_flat.reshape(onp.shape(y) + onp.shape(x))

    R = onp.random.RandomState(0).randn

    A = R(4, 3)
    b = R(4)
    f = lambda x: np.tanh(np.dot(A, x) + b)

    x = R(3)
    self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)

  def testBatchOfCompile(self):
    side = []

    @jit
    def f(x):
      side.append(None)
      return x + x

    g = jit(vmap(f))
    self.assertAllClose(g(onp.ones(2)), 2 * onp.ones(2), check_dtypes=False)
    self.assertEqual(len(side), 1)
    self.assertAllClose(g(2 * onp.ones(2)), 4 * onp.ones(2),
                        check_dtypes=False)
    self.assertEqual(len(side), 1)

  def testSliceLax(self):
    fun = lambda x: lax.slice(x, (2,), (4,))
    R = onp.random.RandomState(0).randn
    x = R(5, 10)

    ans = vmap(fun)(x)
    expected_ans = x[:, 2:4]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testSliceNumpy(self):
    fun = lambda x: x[:, 2]
    R = onp.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(fun)(x)
    expected_ans = x[:, :, 2]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testRevLax(self):
    fun = lambda x: lax.rev(x, [0])
    R = onp.random.RandomState(0).randn
    x = R(2, 3)

    ans = vmap(fun)(x)
    expected_ans = x[:, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (1,), 1)(x)
    expected_ans = x[::-1, :]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testRevNumpy(self):
    fun = lambda x: x[:, ::-1]
    R = onp.random.RandomState(0).randn
    x = R(3, 2, 4)

    ans = vmap(fun)(x)
    expected_ans = x[:, :, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (1,), 1)(x)
    expected_ans = x[:, :, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (2,), 2)(x)
    expected_ans = x[:, ::-1, :]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testNpMaximum(self):
    fun = lambda x: np.maximum(x, 0.0)
    R = onp.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(fun)(x)
    expected_ans = onp.maximum(x, 0.0)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testNpGtrThan(self):
    R = onp.random.RandomState(0).randn
    x = R(10, 5, 3, 7)

    ans = vmap(lambda x: x > 1.0)(x)
    expected_ans = x > 1.0
    self.assertAllClose(ans, expected_ans, check_dtypes=True)

  def testNpMaximumPerExampleGrad(self):
    R = onp.random.RandomState(0).randn
    x = R(10, 5)
    W = R(5, 5)

    fun = lambda W, x: np.sum(np.maximum(np.dot(x, W), 0.0) ** 2)

    ans = vmap(partial(grad(fun), W))(x)

    W_t = np.transpose(W)
    for i in range(10):
      x_ex = x[i:i + 1]

      expected_ans = 2.0 * np.dot(
          np.maximum(np.dot(W_t, np.transpose(x_ex)), 0.0), x_ex)
      expected_ans = np.transpose(expected_ans)

      self.assertAllClose(ans[i], expected_ans, check_dtypes=False)

  def testDotGeneral(self):
    R = onp.random.RandomState(0).randn

    x = R(10, 3, 4, 5)
    y = R(10, 3, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun)(x, y)
    expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))])
    self.assertAllClose(ans, expected, check_dtypes=True)

    x = R(3, 4, 10, 5)
    y = R(3, 10, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(2, 1))(x, y)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    expected = onp.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    x = R(3, 4, 5, 10)
    y = R(3, 5, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(3, None))(x, y)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    expected = onp.stack([fun(x[..., i], y) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

    x = R(3, 4, 5)
    y = R(3, 5, 10, 6)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    ans = vmap(fun, in_axes=(None, 2))(x, y)
    fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))])
    expected = onp.stack([fun(x, y[..., i, :]) for i in range(10)])
    self.assertAllClose(ans, expected, check_dtypes=True)

  def testDot(self):
    # these tests are based on @shoyer's notebook studying gufuncs

    def vecvec(a, b):
      dot = np.dot
      for ndim in range(1, max(a.ndim, b.ndim)):
        a_ax = 0 if a.ndim > ndim else None
        b_ax = 0 if b.ndim > ndim else None
        dot = vmap(dot, in_axes=(a_ax, b_ax))
      return dot(a, b)

    assert vecvec(np.zeros((3,)), np.zeros((3,))).shape == ()
    assert vecvec(np.zeros((2, 3)), np.zeros((3,))).shape == (2,)
    assert vecvec(np.zeros((4, 2, 3)), np.zeros((3,))).shape == (4, 2)

  def testDot2(self):
    R = onp.random.RandomState(0).randn
    xs = R(10, 3)
    ys = R(10, 3)
    ans = vmap(np.dot)(xs, ys)
    expected = onp.einsum('ni,ni->n', xs, ys)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testPad(self):
    R = onp.random.RandomState(0).randn

    fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1)])
    x = R(5, 10).astype(onp.float32)
    ans = vmap(fun)(x)
    expected_ans = np.stack(list(map(fun, x)))
    self.assertAllClose(ans, expected_ans, check_dtypes=False)


    fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1), (0, 1, 0)])
    x = R(5, 10, 3).astype(onp.float32)
    ans = vmap(fun)(x)
    expected_ans = np.stack(list(map(fun, x)))
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testConcatenate(self):
    R = lambda *shape: onp.random.RandomState(0).randn(*shape).astype(onp.float32)

    fun = lambda *args: lax.concatenate(args, dimension=0)
    x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3)
    ans = vmap(fun, in_axes=(0, 1, None))(x, y, z)
    expected_ans = onp.concatenate([x, onp.swapaxes(y, 0, 1),
                                    onp.broadcast_to(z, (10, 4, 3))], 1)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    fun = lambda *args: lax.concatenate(args, dimension=1)
    x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10)
    ans = vmap(fun, in_axes=(0, None, 2))(x, y, z)
    expected_ans = onp.concatenate([x, onp.broadcast_to(y, (10, 2, 3)),
                                    onp.moveaxis(z, 2, 0)], 2)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

  def testJacobianIssue54(self):
    # test modeling the code in https://github.com/google/jax/issues/54

    def func(xs):
      return np.array([x for x in xs])

    xs = np.ones((5, 1))
    jacrev(func)(xs)  # don't crash
    jacfwd(func)(xs)  # don't crash

  def testAny(self):
    # test modeling the code in https://github.com/google/jax/issues/108

    ans = vmap(np.any)(np.array([[True, False], [False, False]]))
    expected = np.array([True, False])
    self.assertAllClose(ans, expected, check_dtypes=True)

  @jtu.skip_on_devices("tpu")
  def testHessian(self):
    # test based on code from sindhwani@google
    def fun(x, t):
      return np.sum(np.power(np.maximum(x, 0.0), 2)) + t

    x = onp.array([-1., -0.5, 0., 0.5, 1.0])

    ans = hessian(lambda x: fun(x, 0.0))(x)
    expected = onp.array([[0., 0., 0., 0., 0.],
                          [0., 0., 0., 0., 0.],
                          [0., 0.,0.5, 0., 0.],
                          [0., 0., 0., 2., 0.],
                          [0., 0., 0., 0., 2.]])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testDynamicSlice(self):
    # test dynamic_slice via numpy indexing syntax
    x = onp.arange(30).reshape((10, 3))

    ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1)
    expected = x[:, 1]
    self.assertAllClose(ans, expected, check_dtypes=False)


    idx = onp.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx)
    expected = x[onp.arange(10), idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = onp.arange(3)
    idx = onp.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx)
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testRandom(self):
    seeds = vmap(random.PRNGKey)(onp.arange(10))
    ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
    expected = onp.stack([random.normal(random.PRNGKey(seed), (3, 2))
                          for seed in onp.arange(10)])
    self.assertAllClose(ans, expected, check_dtypes=False)
    assert len(onp.unique(ans)) == 10 * 3 * 2

  def testSortKeyVal(self):
    k = onp.arange(12)[::-1].reshape(3, 4)
    v = onp.random.RandomState(0).permutation(12).reshape(3, 4)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
    self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
    self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
    self.assertAllClose(sk, k[::-1, :], check_dtypes=True)
    self.assertAllClose(sv, v[::-1, :], check_dtypes=True)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
    self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
    self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
    self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
    self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v)
    self.assertAllClose(sk, onp.broadcast_to(k[0, ::-1], (3, 4)),
                        check_dtypes=True)
    self.assertAllClose(sv, v[:, ::-1], check_dtypes=True)

    sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0])
    self.assertAllClose(sk, k[:, ::-1], check_dtypes=True)
    self.assertAllClose(sv, onp.broadcast_to(v[0, ::-1], (3, 4)),
                        check_dtypes=True)

  def testConvGeneralDilated(self):
    W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
    X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      return y
    grad_loss = grad(lambda params, x: np.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example = np.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          np.reshape(g, (1,) + g.shape)]
    per_example_direct = np.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

  def testMaxPool(self):
    W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
    X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      y = lax.reduce_window(
          y, -np.inf, lax.max, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
      return y
    grad_loss = grad(lambda params, x: np.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example = np.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          np.reshape(g, (1,) + g.shape)]
    per_example_direct = np.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

  def testSumPool(self):
    W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32)
    X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32)

    def f(params, x):
      one = (1, 1)
      dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
      y = lax.conv_general_dilated(
          x, params, one, 'SAME', one, one, dimension_numbers)
      y = lax.reduce_window(
          y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
      return y
    grad_loss = grad(lambda params, x: np.mean(f(params, x) ** 2))

    # Test forward prop.
    per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example = np.reshape(per_example, (10, 5, 5, 5))
    per_example_direct = f(W, X)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)

    # Test gradients.
    per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1)))
    per_example_direct = []
    for i in range(10):
      g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1)))
      per_example_direct += [
          np.reshape(g, (1,) + g.shape)]
    per_example_direct = np.concatenate(per_example_direct, axis=0)
    self.assertAllClose(per_example, per_example_direct, check_dtypes=True)
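
  # A small sketch (not from the original tests) of lax.reduce_window as a sum
  # pool, on values chosen so the result can be checked by hand: a 2x2 'VALID'
  # window over [[1, 2], [3, 4]] yields the single output 1 + 2 + 3 + 4 = 10.
  def _reduceWindowSketch(self):
    x = np.array([[1., 2.], [3., 4.]]).reshape(1, 2, 2, 1)
    y = lax.reduce_window(x, 0., lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'VALID')
    self.assertAllClose(y, np.full((1, 1, 1, 1), 10.), check_dtypes=False)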

  def testSelect(self):
    pred = onp.array([True, False])
    on_true = onp.array([0, 1])
    on_false = onp.array([2, 3])
    ans = vmap(lax.select)(pred, on_true, on_false)
    expected = onp.array([0, 3])
    self.assertAllClose(ans, expected, check_dtypes=True)

    pred = onp.array([False, True])
    on_true = onp.array([0, 1])
    on_false = onp.array([2, 3])
    ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false)
    expected = onp.array([[2, 3],
                          [0, 1]])
    self.assertAllClose(ans, expected, check_dtypes=True)

    pred = True
    on_true = onp.array([0, 1], onp.float32)
    on_false = onp.array(3, onp.float32)
    ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false)
    expected = onp.array([0, 1], onp.float32)
    self.assertAllClose(ans, expected, check_dtypes=True)

    pred = onp.array([False, True])
    on_true = onp.array([0, 1], onp.float32)
    on_false = onp.array(3, onp.float32)
    ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false)
    expected = onp.array([3, 1], onp.float32)
    self.assertAllClose(ans, expected, check_dtypes=True)

    pred = onp.array([False, True])
    on_true = onp.array([2], onp.float32)
    on_false = onp.array([[3, 4]], onp.float32)
    ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false)
    expected = onp.array([[3, 2]], onp.float32)
    self.assertAllClose(ans, expected, check_dtypes=True)
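
  # For reference, a minimal unbatched lax.select call (a sketch, not part of
  # the test): without vmap, the predicate and both branches have matching
  # shapes and the choice is made elementwise.
  def _selectSketch(self):
    pred = np.array([True, False])
    ans = lax.select(pred, np.array([1., 2.]), np.array([3., 4.]))
    self.assertAllClose(ans, np.array([1., 4.]), check_dtypes=False)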

  def testLaxLinalgCholesky(self):
    a = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32)
    a = onp.matmul(a, onp.conj(onp.swapaxes(a, -1, -2)))

    ans = vmap(lax_linalg.cholesky)(a)
    expected = onp.linalg.cholesky(a)
    self.assertAllClose(ans, expected, check_dtypes=False)

    b = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32)
    b = onp.matmul(b, onp.conj(onp.swapaxes(b, -1, -2)))
    b_trans = onp.swapaxes(b, 0, 1)  # shape is (5, 10, 5)

    ans = vmap(lax_linalg.cholesky, in_axes=1, out_axes=0)(b_trans)
    expected = onp.linalg.cholesky(b)
    self.assertAllClose(ans, expected, check_dtypes=False)
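
  # A short sketch (not in the original suite) of what in_axes=1 means in the
  # cholesky test above, using np.sum for simplicity: the batch runs over
  # columns, and out_axes=0 stacks the per-column results along a new leading
  # axis.
  def _inAxesSketch(self):
    x = np.arange(15.).reshape(3, 5)
    ans = vmap(np.sum, in_axes=1, out_axes=0)(x)
    self.assertAllClose(ans, np.sum(x, axis=0), check_dtypes=False)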

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.int32]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (3, 5), onp.array([0, 2]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
            index_vector_dim=1), (1,)),
          (1, (10, 3), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
            index_vector_dim=1), (2,)),
          (1, (10, 3, 5), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
            index_vector_dim=1), (1, 3)),
          (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
            index_vector_dim=1), (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                               slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    ans = vmap(fun, (axis, None))(operand, idxs)
    expected = onp.stack([fun(operand[(slice(None),) * axis + (i,)], idxs)
                          for i in range(operand.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)
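
  # A hedged sketch (not one of the parameterized cases) of a single unbatched
  # call with the first dimension_numbers above: with offset_dims=(),
  # collapsed_slice_dims=(0,) and slice_sizes=(1,), gathering from a 1-D
  # operand behaves like fancy indexing operand[idxs].
  def _gatherSketch(self):
    dnums = lax.GatherDimensionNumbers(
        offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
        index_vector_dim=1)
    operand = np.arange(5.)
    idxs = onp.array([0, 2])
    ans = lax.gather(operand, idxs, dimension_numbers=dnums, slice_sizes=(1,))
    self.assertAllClose(ans, operand[idxs], check_dtypes=False)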

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.float64]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (3, 5), onp.array([0, 2]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
            index_vector_dim=1), (1,)),
          (1, (10, 3), onp.array([0, 0, 0]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
            index_vector_dim=1), (2,)),
          (1, (10, 3, 5), onp.array([0, 2, 1]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
            index_vector_dim=1), (1, 3)),
          (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
            index_vector_dim=1), (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                                   slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    ans = vmap(gfun, (axis, None))(operand, idxs)
    expected = onp.stack([gfun(operand[(slice(None),) * axis + (i,)], idxs)
                          for i in range(operand.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.int32]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (5,), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
              index_vector_dim=1), (1,)),
          (1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T,
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
               index_vector_dim=1), (2,)),
          (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T,
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
               index_vector_dim=1), (1, 3)),
          (0, (10, 5), onp.array([[[0, 2], [1, 0]],
                                  [[1, 2], [0, 3]]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
            index_vector_dim=1), (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                               slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    ans = vmap(fun, (None, axis))(operand, idxs)
    expected = onp.stack([fun(operand, idxs[(slice(None),) * axis + (i,)])
                          for i in range(idxs.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums,
          slice_sizes),
       "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.float64]
      for axis, shape, idxs, dnums, slice_sizes in [
          (0, (5,), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
              index_vector_dim=1), (1,)),
          (1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T,
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
               index_vector_dim=1), (2,)),
          (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T,
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
               index_vector_dim=1), (1, 3)),
          (0, (10, 5), onp.array([[[0, 2], [1, 0]],
                                  [[1, 2], [0, 3]]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
            index_vector_dim=1), (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                                   slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    ans = vmap(gfun, (None, axis))(operand, idxs)
    expected = onp.stack([gfun(operand, idxs[(slice(None),) * axis + (i,)])
                          for i in range(idxs.shape[axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
          dnums, slice_sizes),
       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
       dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
       "rng": rng, "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.int32]
      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
          (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
              index_vector_dim=1), (1,)),
          (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
               index_vector_dim=1), (2,)),
          (0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T,
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
               index_vector_dim=1), (1, 3)),
          (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]],
                                        [[1, 0], [2, 0]]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
            index_vector_dim=1), (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums,
                            slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    operand = rng(shape, dtype)
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs)
    expected = onp.stack([fun(operand[(slice(None),) * op_axis + (i,)],
                              idxs[(slice(None),) * idxs_axis + (i,)])
                          for i in range(idxs.shape[idxs_axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @parameterized.named_parameters(
      {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs,
          dnums, slice_sizes),
       "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype":
       dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes,
       "rng": rng, "rng_idx": rng_idx}
      for dtype in [onp.float32, onp.int32]
      for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [
          (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,),
              index_vector_dim=1), (1,)),
          (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T,
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,),
               index_vector_dim=1), (2,)),
          (0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T,
           lax.GatherDimensionNumbers(
               offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,),
               index_vector_dim=1), (1, 3)),
          (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]],
                                        [[1, 0], [2, 0]]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1),
            index_vector_dim=1), (1, 3)),
      ]
      for rng_idx in [jtu.rand_int(max(shape))]
      for rng in [jtu.rand_default()])
  def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                                dnums, slice_sizes, rng, rng_idx):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx))))
    operand = rng(shape, dtype)
    assert operand.shape[op_axis] == idxs.shape[idxs_axis]
    ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)
    expected = onp.stack([gfun(operand[(slice(None),) * op_axis + (i,)],
                              idxs[(slice(None),) * idxs_axis + (i,)])
                          for i in range(idxs.shape[idxs_axis])])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testNumpyIndexing1(self):
    a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
    ind = onp.array([[0, 1],
                     [2, 0]])
    def f(a, ind):
      return a[:, ind]
    expected = onp.stack([f(a, ind[i, :]) for i in range(ind.shape[0])])
    ans = vmap(f, (None, 0))(a, ind)
    assert onp.all(ans == expected)
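
  # A brief sketch (not part of the original test) making the shapes in
  # testNumpyIndexing1 explicit: for each row of `ind`, a[:, row] has shape
  # (2, 2, 4), and vmapping over the rows stacks those into (2, 2, 2, 4).
  def _indexingShapeSketch(self):
    a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
    ind = onp.array([[0, 1], [2, 0]])
    out = vmap(lambda a, ind: a[:, ind], (None, 0))(a, ind)
    assert out.shape == (2, 2, 2, 4)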

  def testNumpyIndexing2(self):
    a = np.arange(2 * 3 * 4).reshape((2, 3, 4))
    def f(a):
      inds = np.array([0, 2])
      return a[:, inds]
    ans = vmap(f)(a)
    expected = onp.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1)
    assert onp.all(ans == expected)

  def testTranspose(self):
    x = onp.arange(4 * 3 * 3).reshape((4, 3, 3))
    ans = vmap(lambda x: x + x.T)(x)
    expected = x + onp.swapaxes(x, -1, -2)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testTransposePermutation(self):
    x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    ans = vmap(lambda x: np.transpose(x, (1, 0, 2)))(x)
    expected = onp.transpose(x, (0, 2, 1, 3))
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    ans = vmap(lambda x: np.transpose(x, (1, 2, 0)))(x)
    expected = onp.transpose(x, (0, 2, 3, 1))
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = onp.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5))
    ans = vmap(lambda x: np.transpose(x, (1, 2, 0)), in_axes=2)(x)
    expected = onp.transpose(x, (2, 1, 3, 0))
    self.assertAllClose(ans, expected, check_dtypes=False)
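
  # A hedged note (not from the test) on the pattern above: vmapping
  # np.transpose(x, perm) over axis 0 matches transposing the full array with
  # 0 prepended and every entry of perm shifted up by one.
  def _transposePermSketch(self):
    x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5))
    perm = (1, 0, 2)
    ans = vmap(lambda x: np.transpose(x, perm))(x)
    shifted = (0,) + tuple(p + 1 for p in perm)  # (0, 2, 1, 3)
    self.assertAllClose(ans, onp.transpose(x, shifted), check_dtypes=False)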

  def testIssue354(self):
    psd_mat = onp.random.randn(20, 10)
    psd_mat = psd_mat.T.dot(psd_mat)
    vec = onp.random.randn(10)

    def f(scale):
      scaled_mat = scale * psd_mat
      chol = np.linalg.cholesky(scaled_mat)
      return -0.5 * np.sum((np.einsum('ij,j->i', chol, vec))**2)
    vmapped_f = vmap(f)
    vmapped_f_grad = grad(lambda x: np.sum(vmapped_f(x)))

    scales = onp.array([[0.1], [0.2], [0.3], [0.4], [0.5]])
    ans = vmapped_f_grad(scales)  # don't crash!
    expected = onp.stack([grad(f)(scale) for scale in scales])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testIssue387(self):
    # https://github.com/google/jax/issues/387
    R = onp.random.RandomState(0).rand(100, 2)

    def dist_sq(R):
      dR = R[:, np.newaxis, :] - R[np.newaxis, :, :]
      zero = np.zeros_like(dR)
      dR = dR - np.where(np.abs(dR) < 0.5, zero, 0.5 * np.sign(dR))
      return np.sum(dR ** 2, axis=2)

    @jit
    def f(R):
      dr = dist_sq(R)
      return np.sum(R ** 2)

    H = hessian(f)(R)  # don't crash on UnshapedArray
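
  # For context, a sketch (not part of the test): jax's hessian is built from
  # forward-over-reverse differentiation, so the call above should agree with
  # composing jacfwd and jacrev by hand.
  def _hessianSketch(self):
    from jax import jacfwd, jacrev
    f = lambda x: np.sum(x ** 3)
    x = np.arange(3.)
    self.assertAllClose(hessian(f)(x), jacfwd(jacrev(f))(x),
                        check_dtypes=False)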
Beispiel #30
0
 def testSlice(self, shape, dtype, starts, limits, strides, bdims):
     rng = jtu.rand_default(self.rng())
     op = lambda x: lax.slice(x, starts, limits, strides)
     self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)
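
 # A tiny sketch (not from the original) of the unbatched op being batched
 # above: lax.slice takes per-axis start indices, limit indices, and strides.
 def _sliceSketch(self):
     x = np.arange(10)
     ans = lax.slice(x, (2,), (8,), (2,))
     assert list(onp.asarray(ans)) == [2, 4, 6]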