Example No. 1
class TestBFGS(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_func={}_maxiter={}".format(func_and_init[0].__name__, maxiter),
     "maxiter": maxiter, "func_and_init": func_and_init}
    for maxiter in [None]
    for func_and_init in [(rosenbrock, np.zeros(2)),
                          (himmelblau, np.zeros(2)),
                          (matyas, np.ones(2) * 6.),
                          (eggholder, np.ones(2) * 100.)]))
  def test_minimize(self, maxiter, func_and_init):
    # Note: we can't compare step for step with scipy BFGS because our line
    # search is _slightly_ different.

    func, x0 = func_and_init

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          func(jnp),
          x0,
          method='BFGS',
          options=dict(maxiter=maxiter, gtol=1e-6),
      )
      return result.x

    jax_res = min_op(x0)
    scipy_res = scipy.optimize.minimize(func(np), x0, method='BFGS').x
    self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)

  def test_fixes4594(self):
    n = 2
    A = jnp.eye(n) * 1e4
    def f(x):
      return jnp.mean((A @ x) ** 2)
    results = jax.scipy.optimize.minimize(f, jnp.ones(n), method='BFGS')
    self.assertAllClose(results.x, jnp.zeros(n), atol=1e-6, rtol=1e-6)

  @jtu.skip_on_flag('jax_enable_x64', False)
  def test_zakharov(self):
    def zakharov_fn(x):
      ii = jnp.arange(1, len(x) + 1, step=1)
      answer = zakharovFromIndices(x=x, ii=ii)
      return answer

    x0 = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])
    eval_func = jax.jit(zakharov_fn)
    jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS')
    self.assertLess(jax_res.fun, 1e-6)

  def test_minimize_bad_initial_values(self):
    # This test uses deliberately "bad" initial values to check that failed
    # line searches, etc. are handled the same way across implementations.
    initial_value = jnp.array([92, 0.001])
    opt_fn = himmelblau(jnp)
    jax_res = jax.scipy.optimize.minimize(
        fun=opt_fn,
        x0=initial_value,
        method='BFGS',
    ).x
    scipy_res = scipy.optimize.minimize(
        fun=opt_fn,
        jac=jax.grad(opt_fn),
        method='BFGS',
        x0=initial_value
    ).x
    self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)


  def test_args_must_be_tuple(self):
    A = jnp.eye(2) * 1e4
    def f(x):
      return jnp.mean((A @ x) ** 2)
    with self.assertRaisesRegex(TypeError, "args .* must be a tuple"):
      jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')
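
A minimal, standalone sketch of the API these tests exercise, useful for reading them in isolation: jax.scipy.optimize.minimize is used as a jit-compatible analogue of scipy.optimize.minimize. The Rosenbrock definition below is illustrative only; the rosenbrock/himmelblau/matyas/eggholder helpers referenced by the tests are defined elsewhere in the test module.

import jax
import jax.numpy as jnp
import jax.scipy.optimize

def rosenbrock_example(x):
  # Classic Rosenbrock function; its minimum is at x = (1., 1., ..., 1.).
  return jnp.sum(100. * (x[1:] - x[:-1] ** 2) ** 2 + (1. - x[:-1]) ** 2)

result = jax.scipy.optimize.minimize(
    rosenbrock_example, jnp.zeros(2), method='BFGS',
    options=dict(maxiter=100, gtol=1e-6))
print(result.x, result.fun, result.success)  # x should land near (1., 1.)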
Example No. 2
class LaxBackedScipySignalTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.stats implementations"""
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
                 op, jtu.format_shape_dtype_string(xshape, dtype),
                 jtu.format_shape_dtype_string(yshape, dtype), mode),
             "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
             "jsp_op": getattr(jsp_signal, op),
             "osp_op": getattr(osp_signal, op)}
            for mode in ['full', 'same', 'valid']
            for op in ['convolve', 'correlate']
            for dtype in default_dtypes
            for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
            for xshape in shapeset
            for yshape in shapeset))
    def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
        osp_fun = partial(osp_op, mode=mode)
        jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
        tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-8}
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
                 op, jtu.format_shape_dtype_string(xshape, dtype),
                 jtu.format_shape_dtype_string(yshape, dtype), mode),
             "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
             "jsp_op": getattr(jsp_signal, op),
             "osp_op": getattr(osp_signal, op)}
            for mode in ['full', 'same', 'valid']
            for op in ['convolve2d', 'correlate2d']
            for dtype in default_dtypes
            for xshape in twodim_shapes
            for yshape in twodim_shapes))
    def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
        osp_fun = partial(osp_op, mode=mode)
        jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
        tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-14}
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name": "_shape={}_axis={}_type={}_bp={}".format(
                 jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
             "shape": shape, "dtype": dtype, "axis": axis, "type": type,
             "bp": bp}
            for shape in [(5,), (4, 5), (3, 4, 5)]
            for dtype in default_dtypes
            for axis in [0, -1]
            for type in ['constant', 'linear']
            for bp in [0, [0, 2]]))
    def testDetrend(self, shape, dtype, axis, type, bp):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]
        osp_fun = partial(osp_signal.detrend, axis=axis, type=type, bp=bp)
        jsp_fun = partial(jsp_signal.detrend, axis=axis, type=type, bp=bp)
        tol = {np.float32: 1e-5, np.float64: 1e-12}
        self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
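
A minimal sketch of the agreement these convolution tests check, with illustrative arrays: jax.scipy.signal.convolve should match scipy.signal.convolve for each supported mode (correlate, convolve2d, correlate2d, and detrend follow the same pattern).

import numpy as np
import scipy.signal as osp_signal
import jax.scipy.signal as jsp_signal

x = np.arange(6.0, dtype=np.float32)
y = np.array([1.0, 0.5, 0.25], dtype=np.float32)
for mode in ('full', 'same', 'valid'):
    np.testing.assert_allclose(
        osp_signal.convolve(x, y, mode=mode),
        np.asarray(jsp_signal.convolve(x, y, mode=mode)),
        rtol=1e-5, atol=1e-6)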
Example No. 3
class LaxAutodiffTest(jtu.JaxTestCase):

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.name, shapes, itertools.repeat(dtype)),
         "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype,
         "order": rec.order, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in itertools.combinations_with_replacement(shape_group, rec.nargs)
        for dtype in rec.dtypes)
      for rec in LAX_GRAD_OPS))
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.pow:
      raise SkipTest("pow grad imprecise on tpu")
    tol = jtu.join_tolerance(1e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol
    args = tuple(rng(shape, dtype) for shape in shapes)
    check_grads(op, args, order, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
           "op": rec.op, "special_value": special_value, "tol": rec.tol}
          for special_value in rec.values)
      for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
  def testOpGradSpecialValue(self, op, special_value, tol):
    check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}".format(
          jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
       "from_dtype": from_dtype, "to_dtype": to_dtype}
      for from_dtype, to_dtype in itertools.product(inexact_dtypes, repeat=2)))
  def testConvertElementTypeGrad(self, from_dtype, to_dtype):
    rng = jtu.rand_default(self.rng())
    tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
              jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
    args = (rng((2, 3), from_dtype),)
    convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
    convert_element_type = jtu.ignore_warning(category=np.ComplexWarning)(
      convert_element_type)
    check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(), (2, 3)]
      for dtype in grad_float_dtypes))
  def testClampGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    low = operand - dtype(10)
    high = operand + dtype(10)
    # Avoids points near the boundary where the gradient may be inaccurate.
    check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype, "num_arrs": num_arrs}
      for num_arrs in [3]
      for dtype in float_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))))
  def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs):
    rng = jtu.rand_default(self.rng())
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    operands = tuple(rng(shape, dtype) for shape in shapes)
    concatenate = lambda *args: lax.concatenate(args, dim)
    check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding}
       for lhs_shape, rhs_shape, all_strides in itertools.chain(
           [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
            for b, i, j in itertools.product([2, 3], repeat=3)],
           [((4, 2, 1), (3, 2, 1), [(1,)])])
       for strides in all_strides
       for dtype in float_dtypes
       for padding in ["VALID", "SAME"]))
  def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding):
    rng = jtu.rand_small(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv, window_strides=strides, padding=padding,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil}
       for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
       itertools.chain(
           [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
             [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
             [(1, 1), (2, 1)], [(1, 1)])
            for b, i, j in itertools.product([2, 3], repeat=3)],
           [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)],
             [(1,), (2,)], [(1,), (2,)])])
       for strides in all_strides
       for rhs_dil in rhs_dils
       for lhs_dil in lhs_dils
       for dtype in float_dtypes
       for padding in all_pads))
  def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                     padding, lhs_dil, rhs_dil):
    rng = jtu.rand_small(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_with_general_padding, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "dimension_numbers": dim_nums,
       "perms": perms, "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count}
      for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)])
      for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
          ([(b * batch_group_count, i * feature_group_count, 6, 7),
            (b * batch_group_count, i * feature_group_count, 0, 4)],  # lhs_shape
           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for lhs_shape in lhs_shapes
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in grad_inexact_dtypes
      for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
        ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]))
  def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dil, rhs_dil, dimension_numbers,
                                 perms, feature_group_count, batch_group_count):
    if dtype == np.float16:
      raise SkipTest("float16 numerical issues")  # TODO(mattjj): resolve

    if (jtu.device_under_test() == "cpu" and dtype == np.float64 and
        lhs_shape == (1,1,6,7) and rhs_shape == (2,1,1,2) and strides == (2, 1)
        and padding == ((0, -1), (0, 0)) and lhs_dil == (1, 1) and
        rhs_dil == (1, 1)):
      # TODO(b/173608403): reenable after LLVM fix.
      raise SkipTest("Skipping test due to LLVM lowering bug")
    rng = jtu.rand_default(self.rng())
    tol = {dtypes.bfloat16: 1e-0, np.float16: 5e-1, np.float32: 1e-3}

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(np.take(lhs_shape, lhs_perm))
    rhs_shape = list(np.take(rhs_shape, rhs_perm))

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype}
      for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)]
      for dtype in float_dtypes))
  def testDotGrad(self, lhs_shape, rhs_shape, dtype):
    rng = jtu.rand_default(self.rng())
    tol = {np.float16: 1e-1, np.float32: 1e-4}
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)
    # check that precision config is preserved
    result, pullback = api.vjp(dot, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "Precision.HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 2), (2, 4), (([1], [0]), ([], []))),
          ((3, 5), (2, 5), (([1], [1]), ([], []))),
          ((5, 3), (5, 2), (([0], [0]), ([], []))),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 5, 2), (2, 4, 5), (([2], [0]), ([1], [2]))),
          ((7, 3, 5, 2), (2, 2, 4, 5), (([3], [0]), ([2], [3]))),
      ]
      for dtype in float_dtypes))
  def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers):
    rng = jtu.rand_small(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                          precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
    # check that precision config is preserved
    result, pullback = api.vjp(dot_general, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "Precision.HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
          shape, np.dtype(dtype).name, broadcast_sizes),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes}
      for shape in [(), (2, 3)]
      for dtype in float_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]))
  def testBroadcastGrad(self, shape, dtype, broadcast_sizes):
    rng = jtu.rand_default(self.rng())
    args = (rng(shape, dtype),)
    broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
    check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in float_dtypes))
  def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions):
    rng = jtu.rand_default(self.rng())
    operand = rng(inshape, dtype)
    broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_perm={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          permutation),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "permutation": permutation}
      for dtype in float_dtypes
      for arg_shape, out_shape, permutation in [
          [(3, 4), (12,), None],
          [(2, 1, 4), (8,), None],
          [(2, 2, 4), (2, 8), None],
          [(3, 4), (12,), (0, 1)],
          [(3, 4), (12,), (1, 0)],
          [(2, 1, 4), (8,), (0, 2, 1)],
          [(2, 1, 4), (8,), (2, 0, 1)],
          [(2, 2, 4), (2, 8), (0, 2, 1)],
          [(2, 2, 4), (2, 8), (2, 0, 1)],
      ]))
  def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype):
    rng = jtu.rand_default(self.rng())
    operand = rng(arg_shape, dtype)
    reshape = lambda x: lax.reshape(x, out_shape, permutation)
    check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads),
       "shape": shape, "dtype": dtype, "pads": pads}
      for shape in [(2, 3)]
      for dtype in float_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]))
  def testPadGrad(self, shape, dtype, pads):
    rng = jtu.rand_small(self.rng())
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = np.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)

  def testReverseGrad(self):
    rev = lambda operand: lax.rev(operand, dimensions)

    dimensions = [0]
    check_grads(rev, (np.array([3., 2., 1.]),), 2)

    dimensions = [0, 1]
    check_grads(rev, (np.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
                rtol={np.float32: 3e-3})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}".format(
          jtu.format_shape_dtype_string(pred_shape, np.bool_),
          jtu.format_shape_dtype_string(arg_shape, dtype)),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for dtype in float_dtypes))
  def testSelectGrad(self, pred_shape, arg_shape, dtype):
    rng = jtu.rand_default(self.rng())
    pred = rng(pred_shape, np.bool_)
    on_true = rng(arg_shape, dtype)
    on_false = rng(arg_shape, dtype)
    select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
    check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides}
      for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
        [(3, 3, 5), (0, 2, 0), (3, 2, 5), (1, 2, 1)]
      ]
      for dtype in float_dtypes))
  def testSliceGrad(self, shape, dtype, starts, limits, strides):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    slice = lambda x: lax.slice(x, starts, limits, strides)
    check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, size_indices),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "size_indices": size_indices}
      for shape, start_indices, size_indices in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes))
  def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
    check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, update_shape),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "update_shape": update_shape}
      for shape, start_indices, update_shape in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes))
  def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, update_shape):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    start_indices = np.array(start_indices)

    dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
    check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)

    dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
    check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)

    dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
    check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),
       "shape": shape, "dtype": dtype, "perm": perm}
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
      for dtype in float_dtypes))
  def testTransposeGrad(self, shape, dtype, perm):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    transpose = lambda x: lax.transpose(x, perm)
    check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, float_dtypes + jtu.dtypes.complex, jtu.rand_default),
          (-np.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
          (np.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
          (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(), ()],
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
          [(3, 0, 5), (1,)],
      ]))
  def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.mul:
      raise SkipTest("unimplemented case")
    tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-1, np.float32: 1e-1,
           np.float64: 1e-3, np.complex64: 1e-1}
    operand = rng(shape, dtype)
    init_val = np.asarray(init_val, dtype=dtype)
    reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
    eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
           1e-1 if dtype == dtypes.bfloat16 else
           1e-2 if dtypes.finfo(dtype).bits == 32 else None)
    if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
      check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_reducedims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), dims),
       "shape": shape, "dtype": dtype, "dims": dims}
      for dtype in grad_float_dtypes
      for shape, dims in [
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
          [(3, 0, 5), (1,)],
      ]))
  def testReducePairGrad(self, shape, dtype, dims):
    rng = jtu.rand_default(self.rng(), scale=1)
    tol = {np.float32: 1e-2, np.float64: 1e-4}
    operands = (rng(shape, dtype), rng(shape, dtype))
    init_vals = (np.array(0, dtype), np.array(1, dtype))
    def op(xs, ys):
      return (xs[0] + ys[0], xs[1] * ys[1])
    reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
    check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                         "_basedilation={}_windowdilation={}")
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
               strides, padding, base_dilation, window_dilation),
       "op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
       "dims": dims, "strides": strides, "padding": padding,
       "base_dilation": base_dilation, "window_dilation": window_dilation,
       "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, grad_float_dtypes, jtu.rand_small),
          (-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
          (np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
      ]
      for shape, dims, strides, padding, base_dilation, window_dilation in (
        itertools.chain(
          itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)],
            ["VALID", "SAME", [(0, 3), (1, 2)]],
            [(1, 1)] + ([(2, 3)] if op is lax.add else []),
            [(1, 1)] + ([(1, 2)] if op is lax.add else [])),
          itertools.product(
            [(3, 2, 4, 6)],
            [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)],
            ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
            [(1, 1, 1, 1)] + ([(2, 1, 3, 2)] if op is lax.add else []),
            [(1, 1, 1, 1)] + ([(1, 2, 2, 1)] if op is lax.add else []))))
      for dtype in dtypes))
  @jtu.ignore_warning(category=UserWarning,
                      message="Using reduced precision for gradient.*")
  def testReduceWindowGrad(
      self, op, init_val, dtype, shape, dims, strides,
      padding, base_dilation, window_dilation, rng_factory):
    rng = rng_factory(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    gradient_order = 3
    # We need this conditional and the corresponding loop logic to be in the
    # test method, rather than at the parameterized test level, because it
    # depends on FLAGS for the device under test.
    # TODO(b/31565929): enable when fixed.
    if jtu.device_under_test() == "tpu" and op is not lax.add:
      if (len(shape) != 4 or dims != (1, 1, 2, 1)
          or not isinstance(padding, str)):
        raise SkipTest("Only R4 SelectAndScatter implemented on TPU")

      # TODO(b/73062247): need variadic reduce-window for better precision.
      gradient_order = 1

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding,
                               base_dilation, window_dilation)

    operand = rng(shape, dtype)
    if op is lax.add:
      eps = 1.
      tol = None
    else:
      # this test can fail if there are duplicates in operand
      self.assertEqual(np.unique(operand).size, operand.size,
                       msg="test requires operand elements to be unique.")
      eps = 1e-2
      tol = {np.float16: 1e-1, np.float32: 6e-2, np.float64: 6e-2}
    check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
                eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_axis={}_reverse={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
               reverse),
       "op": op, "shape": shape, "dtype": dtype,
       "axis": axis, "reverse": reverse}
      for op, types in [
          (lax.cumsum, [np.float32, np.float64]),
          (lax.cumprod, [np.float32, np.float64]),
      ]
      for dtype in types
      for shape in [[10], [3, 4, 5]]
      for axis in range(len(shape))
      for reverse in [False, True]))
  def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                   else jtu.rand_small)
    rng = rng_factory(self.rng())
    check_grads(partial(op, axis=axis, reverse=reverse), (rng(shape, dtype),),
                order=2)


  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
       "shape": shape, "dtype": dtype, "axis": axis, "is_stable": is_stable}
      for dtype in [np.float32]
      for shape in [(5,), (5, 7)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]))
  def testSortGrad(self, shape, dtype, axis, is_stable):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
    check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)

  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, key_dtype),
          jtu.format_shape_dtype_string(shape, val_dtype),
          axis, is_stable),
       "shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype,
       "axis": axis, "is_stable": is_stable}
      for key_dtype in [np.float32]
      for val_dtype in [np.float32]
      for shape in [(3,), (5, 3)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]))
  def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable):
    rng = jtu.rand_default(self.rng())
    # This test relies on the property that wherever keys are tied, values are
    # too, since we don't guarantee the same ordering of values with equal keys.
    # To avoid that case, we generate unique keys (globally in the key array).
    def args_maker():
      flat_keys = np.arange(prod(shape), dtype=key_dtype)
      keys = self.rng().permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values
    keys, values = args_maker()

    fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
    check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "shape": shape, "dtype": dtype, "k": k}
      for dtype in [np.float32,]
      for shape in [(4,), (5, 5), (2, 1, 4)]
      for k in [1, 3]))
  def testTopKGrad(self, shape, dtype, k):
    flat_values = np.arange(prod(shape), dtype=dtype)
    values = self.rng().permutation(flat_values).reshape(shape)
    fun = lambda vs: lax.top_k(vs, k=k)[0]
    check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes}
      for dtype in float_dtypes
      for shape, idxs, axes in [
          [(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
          [(3, 4, 5), (np.array([-1, -2]),), (0,)],
          [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
          [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)],
      ]))
  def testIndexTakeGrad(self, shape, dtype, idxs, axes):
    rng = jtu.rand_default(self.rng())
    src = rng(shape, dtype)
    index_take = lambda src: lax.index_take(src, idxs, axes)
    check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
          slice_sizes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for shape, idxs, dnums, slice_sizes, max_idx in [
          ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,), 5),
          ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,), 9),
          ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]))
  def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_idx_factory):
    rng = jtu.rand_default(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
                                  slice_sizes=slice_sizes)
    x = rng(shape, dtype)
    check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]))
  def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                         rng_idx_factory):
    rng = jtu.rand_default(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_add = lambda x, y: lax.scatter_add(x, idxs, y,
                                               dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
      ]
      # Scatters with conflicting indices are not deterministic on GPU, so we
      # use indices that do not collide.
      for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)]))
  def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                      rng_idx_factory):
    rng = jtu.rand_default(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  def testScatterGradSymbolicZeroUpdate(self):
    # https://github.com/google/jax/issues/1901
    def f(x):
      n = x.shape[0]
      y = np.arange(n, dtype=x.dtype)
      return jax.ops.index_update(x, np.diag_indices(n), y)
    rng = jtu.rand_default(self.rng())
    check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
                1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]))
  def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums):
    rng = jtu.rand_default(self.rng())
    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]))
  def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums):
    rng = jtu.rand_default(self.rng())
    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  def testStopGradient(self):
    def f(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def f2(x, y):
      return lax.sin(x) * lax.cos(y)

    x = 3.14
    ans = api.grad(f)(x)
    expected = api.grad(f2)(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(api.grad(f))(x)
    expected = api.grad(api.grad(f2))(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
    expected = np.array(0.0)
    self.assertAllClose(ans, expected, check_dtypes=False)

    with jax.enable_checks(False):
      with self.assertRaises(TypeError):
        lax.stop_gradient(lambda x: x)

  # TODO(mattjj): make this a more systematic test
  def testRemainder(self):
    rng = np.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(3, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 1))
    assert not set(np.unique(x)) & set(np.unique(y))
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

    rng = np.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(1, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 4))
    assert not set(np.unique(x)) & set(np.unique(y))
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

  def testHigherOrderGradientOfReciprocal(self):
    # Regression test for https://github.com/google/jax/issues/3136
    def inv(x):
      # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
      return 1 / x
    grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
    self.assertAllClose(np.float32(0.0439453125), grad_fn(np.float32(4.)))

  def test_linear_transpose_real(self):
    f = lambda x: x.real
    transpose = api.linear_transpose(f, 1.j)
    actual, = transpose(1.)
    expected = 1.
    self.assertEqual(actual, expected)

  def test_linear_transpose_imag(self):
    f = lambda x: x.imag
    transpose = api.linear_transpose(f, 1.j)
    actual, = transpose(1.)
    expected = -1.j
    self.assertEqual(actual, expected)
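
Nearly every test in this class follows the same recipe: build random operands, wrap a lax primitive in a lambda that closes over the static parameters, and hand it to check_grads, which compares forward- and reverse-mode derivatives against numerical differences. A minimal standalone sketch of that recipe, with an illustrative operand and lax.sin standing in for the parameterized ops:

import numpy as np
from jax import lax
from jax.test_util import check_grads

operand = np.array([[0.1, 0.7], [1.3, 2.9]], dtype=np.float32)
# Check first- and second-order derivatives in both differentiation modes,
# with loose tolerances appropriate for float32.
check_grads(lax.sin, (operand,), order=2, modes=["fwd", "rev"],
            atol=1e-2, rtol=1e-2)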
Example No. 4
class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
    def test_primitive_coverage(self):
        """Fail if there are JAX primitives that are not implemented."""
        # Harvest primitives from XLA translation tables
        all_primitives = (set(xla.translations)
                          | set(xla.backend_specific_translations['cpu'])
                          | set(xla.backend_specific_translations['gpu'])
                          | set(xla.backend_specific_translations['tpu'])
                          | set(xla.initial_style_translations)
                          | set(xla.parallel_translations))

        tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl)
        tf_not_yet_impl = set(jax.experimental.jax2tf.jax2tf.tf_not_yet_impl)

        all_primitives = tuple(sorted(all_primitives, key=str))
        for p in all_primitives:
            if p in tf_not_yet_impl:
                self.assertNotIn(
                    p, tf_impl
                )  # Should not be in both tf_impl and tf_not_yet_impl
            else:
                self.assertIn(p, tf_impl)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in [
                jnp.add, jnp.subtract, jnp.multiply, jnp.divide, jnp.less,
                jnp.less_equal, jnp.equal, jnp.greater, jnp.greater_equal,
                jnp.not_equal, jnp.maximum, jnp.minimum
            ]))
    def test_type_promotion(self, f_jax=jnp.add):
        # We only test a few types here, as tensorflow does not support many
        # types like uint* or bool in binary ops.
        types = [np.int32, np.int64, np.float32]
        for x_dtype in types:
            for y_dtype in types:
                x = np.array([1, 2], dtype=x_dtype)
                y = np.array([3, 4], dtype=y_dtype)
                self.ConvertAndCompare(f_jax, x, y, with_function=True)

    def test_concat(self):
        values = [
            np.array([1, 2], dtype=np.float32),
            np.array([1, 2], dtype=np.int32),
            np.array([1, 2], dtype=np.int8)
        ]
        f_jax = jax.jit(lambda x: jnp.concatenate(x, axis=0))
        self.ConvertAndCompare(f_jax, values, with_function=True)

    @primitive_harness.parameterized(primitive_harness.lax_pad)
    def test_pad(self, harness: primitive_harness.Harness):
        # TODO: figure out the bfloat16 story
        if harness.params["dtype"] is dtypes.bfloat16:
            raise unittest.SkipTest("bfloat16 not implemented")
        # TODO: fix pad with negative padding in XLA (fixed on 06/16/2020)
        if any([lo < 0 or hi < 0 for lo, hi, mid in harness.params["pads"]]):
            raise unittest.SkipTest("pad with negative pad not supported")
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               with_function=False)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in LAX_ELEMENTWISE_UNARY))
    def test_unary_elementwise(self, f_jax=lax.abs):
        x = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1, 1.4, 1.6],
                     dtype=np.float32)
        f_tf = tf.function(jax2tf.convert(f_jax))
        r_jax = f_jax(x)
        r_tf = f_tf(x)
        self.assertAllClose(r_jax[np.isfinite(r_jax)],
                            r_tf[np.isfinite(r_tf)],
                            atol=1e-4)

    def test_bitwise_not(self):
        x = np.array([-1, 3, -2, 0, 0, 2, 1, 3], dtype=np.int32)
        f_jax = jax.jit(lax.bitwise_not)
        f_tf = tf.function(jax2tf.convert(f_jax))
        r_jax = f_jax(x)
        r_tf = f_tf(x)
        self.assertAllClose(r_jax[np.isfinite(r_jax)],
                            r_tf[np.isfinite(r_tf)],
                            atol=1e-4)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in LAX_ELEMENTWISE_BINARY))
    def test_binary_elementwise(self, f_jax=lax.add):
        a = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1, 1.4, 1.6],
                     dtype=np.float32)
        b = np.array([-1.6, 1.4, 1.0, 0.0, 0.1, 0.2, 1, 1.4, -1.6],
                     dtype=np.float32)
        f_tf = tf.function(jax2tf.convert(f_jax))
        r_jax = f_jax(a, b)
        r_tf = f_tf(a, b)
        # For igamma/igammac, JAX returns 0 and 1 instead of NaN for inputs
        # outside the domain, while TensorFlow does so for other input
        # combinations, so mask those values out before comparing.
        if f_jax in (lax.igamma, lax.igammac):
            # Make returned array writeable.
            r_jax = np.copy(r_jax)
            r_jax[r_jax == 0] = np.nan
            r_jax[r_jax == 1] = np.nan
            r_tf = np.copy(r_tf)
            r_tf[r_tf == 0] = np.nan
            r_tf[r_tf == 1] = np.nan
        self.assertAllClose(r_jax[np.isfinite(r_jax)],
                            r_tf[np.isfinite(r_tf)],
                            atol=1e-4)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in LAX_LOGICAL_ELEMENTWISE_BINARY))
    def test_binary_logical_elementwise(self, f_jax):
        a = np.array([1, 3, 2, 0, 0, 2, 1, 3], dtype=np.uint32)
        b = np.array([1, 2, 3, 0, 1, 0, 2, 3], dtype=np.uint32)
        f_tf = tf.function(jax2tf.convert(f_jax))
        r_jax = f_jax(a, b)
        r_tf = f_tf(a, b)
        self.assertAllClose(r_jax[np.isfinite(r_jax)],
                            r_tf[np.isfinite(r_tf)],
                            atol=1e-4)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in LAX_LOGICAL_ELEMENTWISE_BINARY))
    def test_binary_logical_elementwise_bool(self, f_jax):
        if f_jax == lax.shift_left:
            self.skipTest("Shift of bool not supported")
        a = np.array([0, 0, 1, 1, 0, 0, 1, 1], dtype=np.bool_)
        b = np.array([0, 1, 0, 1, 0, 1, 0, 1], dtype=np.bool_)
        f_tf = tf.function(jax2tf.convert(f_jax))
        r_jax = f_jax(a, b)
        r_tf = f_tf(a, b)
        self.assertAllClose(r_jax, r_tf)

    # TODO(necula): combine tests that are identical except for the harness;
    # wait until we get more experience with using harnesses.
    @primitive_harness.parameterized(primitive_harness.lax_shift_left)
    def test_shift_left(self, harness):
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               with_function=True)

    @primitive_harness.parameterized(primitive_harness.lax_shift_right_logical)
    def test_shift_right_logical(self, harness):
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               with_function=True)

    @primitive_harness.parameterized(
        primitive_harness.lax_shift_right_arithmetic)
    def test_shift_right_arithmetic(self, harness):
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               with_function=True)

    @primitive_harness.parameterized(primitive_harness.lax_slice)
    def test_slice(self, harness):
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               with_function=True)

    @primitive_harness.parameterized(primitive_harness.lax_dynamic_slice)
    def test_dynamic_slice(self, harness):
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               with_function=True)

    @primitive_harness.parameterized(primitive_harness.lax_dynamic_update_slice)
    def test_dynamic_update_slice(self, harness):
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               with_function=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in (lax.betainc, )))
    def test_trinary_elementwise(self, f_jax):
        a = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6],
                     dtype=np.float32)
        b = np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6],
                     dtype=np.float32)
        c = np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6],
                     dtype=np.float32)
        f_tf = tf.function(jax2tf.convert(f_jax))
        r_jax = f_jax(a, b, c)
        r_tf = f_tf(a, b, c)
        self.assertAllClose(r_jax[np.isfinite(r_jax)],
                            r_tf[np.isfinite(r_tf)],
                            atol=1e-4)

    @primitive_harness.parameterized(primitive_harness.lax_squeeze)
    def test_squeeze(self, harness: primitive_harness.Harness):
        self.ConvertAndCompare(harness.dyn_fun,
                               *harness.dyn_args_maker(self.rng()),
                               with_function=True)

    def test_gather(self):
        values = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
        indices = np.array([0, 1], dtype=np.int32)
        for axis in (0, 1):
            f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
            self.ConvertAndCompare(f_jax, values, indices, with_function=True)

    def test_boolean_gather(self):
        values = np.array([[True, True], [False, True], [False, False]],
                          dtype=np.bool_)
        indices = np.array([0, 1], dtype=np.int32)
        for axis in [0, 1]:
            f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
            self.ConvertAndCompare(f_jax, values, indices, with_function=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in REDUCE))
    def test_reduce_ops_with_numerical_input(self, f_jax):
        values = [np.array([1, 2, 3], dtype=np.float32)]
        self.ConvertAndCompare(f_jax, values, with_function=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in (jnp.cumsum, jnp.cumprod)))
    def test_cumulated_ops(self, f_jax):
        values = np.array([1, 2, 3], dtype=np.float32)
        self.ConvertAndCompare(f_jax, values, with_function=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{op.__name__}", op=op) for op in INDEX))
    def test_scatter_static(self, op):
        values = np.ones((5, 6), dtype=np.float32)
        update = np.float32(6.)
        f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u))
        self.ConvertAndCompare(f_jax, values, update, with_function=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in REDUCE))
    def test_reduce_ops_with_boolean_input(self, f_jax):
        values = [np.array([True, False, True], dtype=np.bool_)]
        self.ConvertAndCompare(f_jax, values, with_function=True)

    def test_gather_rank_change(self):
        params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]])
        indices = jnp.array([[1, 1, 2], [0, 1, 0]])
        f_jax = jax.jit(lambda i: params[i])
        self.ConvertAndCompare(f_jax, indices, with_function=True)

    def test_prngsplit(self):
        f_jax = jax.jit(lambda key: jax.random.split(key, 2))
        for rng_key in [
                jax.random.PRNGKey(42),
                np.array([0, 0], dtype=np.uint32),
                np.array([0xFFFFFFFF, 0], dtype=np.uint32),
                np.array([0, 0xFFFFFFFF], dtype=np.uint32),
                np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)
        ]:
            self.ConvertAndCompare(f_jax, rng_key, with_function=True)

    def test_zeros_like(self):
        v = np.float32(2.)
        f_jax = jax.ad_util.zeros_like_jaxval
        self.ConvertAndCompare(f_jax, v)

    def test_stop_gradient(self):
        f = jax2tf.convert(lax.stop_gradient)
        self.assertEqual(f(tf.ones([])), 1.)
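
# A minimal, illustrative sketch (not part of the test class above) of the
# convert-and-compare pattern these tests exercise: wrap a JAX function with
# jax2tf.convert, stage it through tf.function, and compare against the JAX
# result. The function f_jax and the tolerance are assumptions made for this
# example, not taken from the test suite.
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):
    return jnp.sin(x) * 2.0

f_tf = tf.function(jax2tf.convert(f_jax))
x = np.linspace(0.0, 1.0, 8, dtype=np.float32)
np.testing.assert_allclose(np.asarray(f_jax(x)), f_tf(x).numpy(), atol=1e-4)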
Example No. 5
0
class NumpyLinalgTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testCholesky(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)

        def args_maker():
            factor_shape = shape[:-1] + (2 * shape[-1], )
            a = rng(factor_shape, dtype)
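            # a @ conj(a).T is Hermitian and (for a random wide factor) almost surely
            # positive definite, so the Cholesky factorizations below are well defined.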
            return [onp.matmul(a, np.conj(T(a)))]

        if np.issubdtype(dtype, np.complexfloating) and (
                len(shape) > 2 or jtu.device_under_test() != "cpu"):
            self.skipTest(
                "Unimplemented case for complex Cholesky decomposition.")

        self._CheckAgainstNumpy(onp.linalg.cholesky,
                                np.linalg.cholesky,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.cholesky,
                              args_maker,
                              check_dtypes=True)

        if onp.finfo(dtype).bits == 64:
            jtu.check_grads(np.linalg.cholesky, args_maker(), order=2)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [0, 4, 5, 25]  # TODO(mattjj): complex64 unstable on large sizes?
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testDet(self, n, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng((n, n), dtype)]

        self._CheckAgainstNumpy(onp.linalg.det,
                                np.linalg.det,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)

    def testDetOfSingularMatrix(self):
        x = np.array([[-1., 3. / 2], [2. / 3, -1.]], dtype=onp.float32)
        self.assertAllClose(onp.float32(0),
                            jsp.linalg.det(x),
                            check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(0, 0), (1, 1), (3, 3), (4, 4), (10, 10), (200, 200),
                        (2, 2, 2), (2, 3, 3), (3, 2, 2)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testSlogdet(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(onp.linalg.slogdet,
                                np.linalg.slogdet,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)

    def testIssue1213(self):
        for n in range(5):
            mat = np.array(
                [onp.diag(onp.ones([5], dtype=onp.float32)) * (-.01)] * 2)
            args_maker = lambda: [mat]
            self._CheckAgainstNumpy(onp.linalg.slogdet,
                                    np.linalg.slogdet,
                                    args_maker,
                                    check_dtypes=True,
                                    tol=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    # TODO(phawkins): enable when there is an eigendecomposition implementation
    # for GPU/TPU.
    @jtu.skip_on_devices("gpu", "tpu")
    def testEig(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        n = shape[-1]
        args_maker = lambda: [rng(shape, dtype)]

        # Norm, adjusted for dimension and type.
        def norm(x):
            norm = onp.linalg.norm(x, axis=(-2, -1))
            return norm / ((n + 1) * onp.finfo(dtype).eps)
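            # i.e. the residual measured in units of (n + 1) * machine epsilon, which
            # makes the dimension- and precision-aware threshold below a bare "< 100".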

        a, = args_maker()
        w, v = np.linalg.eig(a)
        self.assertTrue(
            onp.all(norm(onp.matmul(a, v) - w[..., None, :] * v) < 100))

        self._CompileAndCheck(partial(np.linalg.eig),
                              args_maker,
                              check_dtypes=True,
                              rtol=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (5, 5)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("gpu", "tpu")
    def testEigBatching(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        shape = (10, ) + shape
        args = rng(shape, dtype)
        ws, vs = vmap(np.linalg.eig)(args)
        self.assertTrue(
            onp.all(
                onp.linalg.norm(onp.matmul(args, vs) -
                                ws[..., None, :] * vs) < 1e-3))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}_lower={}".format(
                jtu.format_shape_dtype_string((n, n), dtype), lower),
            "n":
            n,
            "dtype":
            dtype,
            "lower":
            lower,
            "rng":
            rng
        } for n in [0, 4, 5, 50] for dtype in float_types + complex_types
                            for lower in [False, True]
                            for rng in [jtu.rand_default()]))
    # TODO(phawkins): enable when there is an eigendecomposition implementation
    # for TPU.
    @jtu.skip_on_devices("tpu")
    def testEigh(self, n, dtype, lower, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng((n, n), dtype)]

        uplo = "L" if lower else "U"

        # Norm, adjusted for dimension and type.
        def norm(x):
            norm = onp.linalg.norm(x, axis=(-2, -1))
            return norm / ((n + 1) * onp.finfo(dtype).eps)

        a, = args_maker()
        a = (a + onp.conj(a.T)) / 2
        w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a),
                              UPLO=uplo,
                              symmetrize_input=False)
        self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
        self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30)

        self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo),
                              args_maker,
                              check_dtypes=True,
                              rtol=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_lower={}".format(
                jtu.format_shape_dtype_string(shape, dtype), lower),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng,
            "lower":
            lower
        } for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]
                            for lower in [True, False]))
    # TODO(phawkins): enable when there is an eigendecomposition implementation
    # for TPU.
    @jtu.skip_on_devices("tpu")
    def testEighGrad(self, shape, dtype, rng, lower):
        self.skipTest("Test fails with numeric errors.")
        uplo = "L" if lower else "U"
        a = rng(shape, dtype)
        a = (a + onp.conj(a.T)) / 2
        a = onp.tril(a) if lower else onp.triu(a)
        # Gradient checks will fail without symmetrization as the eigh jvp rule
        # is only correct for tangents in the symmetric subspace, whereas the
        # checker checks against unconstrained (co)tangents.
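        # (Illustration: an arbitrary tangent t can be projected onto that subspace
        #  as (t + conj(t).T) / 2, which is effectively what symmetrize_input=True
        #  applies to the perturbation here.)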
        if dtype not in complex_types:
            f = partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)
        else:  # only check eigenvalue grads for complex matrices
            f = lambda a: partial(
                np.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
        jtu.check_grads(f, (a, ), 2, rtol=1e-1)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_lower={}".format(
                jtu.format_shape_dtype_string(shape, dtype), lower),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng,
            "lower":
            lower,
            "eps":
            eps
        } for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
                            for dtype in complex_types
                            for rng in [jtu.rand_default()]
                            for lower in [True, False] for eps in [1e-4]))
    # TODO(phawkins): enable when there is an eigendecomposition implementation
    # for TPU.
    @jtu.skip_on_devices("tpu")
    def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
        _skip_if_unsupported_type(dtype)
        # Special case to test for complex eigenvector grad correctness.
        # Exact eigenvector coordinate gradients are hard to test numerically for
        # complex eigensystem solvers, given the extra per-eigenvector phase degree
        # of freedom. Instead, we numerically verify the eigensystem properties on
        # the perturbed eigenvectors. You only ever want to optimize eigenvector
        # directions, not coordinates!
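        # (For intuition: if a @ v == w * v, then a @ (exp(1j * phi) * v) ==
        #  w * (exp(1j * phi) * v) for any real phi, so complex eigenvector
        #  coordinates are only defined up to a phase.)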
        uplo = "L" if lower else "U"
        a = rng(shape, dtype)
        a = (a + onp.conj(a.T)) / 2
        a = onp.tril(a) if lower else onp.triu(a)
        a_dot = eps * rng(shape, dtype)
        a_dot = (a_dot + onp.conj(a_dot.T)) / 2
        a_dot = onp.tril(a_dot) if lower else onp.triu(a_dot)
        # evaluate eigenvector gradient and groundtruth eigensystem for perturbed input matrix
        f = partial(np.linalg.eigh, UPLO=uplo)
        (w, v), (dw, dv) = jvp(f, primals=(a, ), tangents=(a_dot, ))
        new_a = a + a_dot
        new_w, new_v = f(new_a)
        new_a = (new_a + onp.conj(new_a.T)) / 2
        # Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues.
        RTOL = 1e-2
        assert onp.max(
            onp.abs((onp.diag(
                onp.dot(onp.conj(
                    (v + dv).T), onp.dot(new_a,
                                         (v + dv)))) - new_w) / new_w)) < RTOL
        # Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues.
        assert onp.max(
            onp.linalg.norm(
                onp.abs(new_w * (v + dv) - onp.dot(new_a, (v + dv))), axis=0) /
            onp.linalg.norm(onp.abs(new_w * (v + dv)), axis=0)) < RTOL

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (5, 5)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")
    def testEighBatching(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        shape = (10, ) + shape
        args = rng(shape, dtype)
        args = (args + onp.conj(T(args))) / 2
        ws, vs = vmap(jsp.linalg.eigh)(args)
        self.assertTrue(
            onp.all(
                onp.linalg.norm(onp.matmul(args, vs) -
                                ws[..., None, :] * vs) < 1e-3))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_ord={}_axis={}_keepdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), ord, axis,
                keepdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims,
            "ord":
            ord,
            "rng":
            rng
        } for axis, shape in [(None, (1,)), (None, (7,)), (None, (5, 8)),
                              (0, (9,)), (0, (4, 5)), ((1,), (10, 7, 3)),
                              ((-2,), (4, 8)), (-1, (6, 3)), ((0, 2), (3, 4, 5)),
                              ((2, 0), (7, 8, 9)), (None, (7, 8, 11))]
                            for keepdims in [False, True]
                            for ord in (
                                [None] if axis is None and len(shape) > 2
                                else [None, 0, 1, 2, 3, -1, -2, -3, np.inf, -np.inf]
                                if ((axis is None and len(shape) == 1)
                                    or isinstance(axis, int)
                                    or (isinstance(axis, tuple) and len(axis) == 1))
                                else [None, 'fro', 1, 2, -1, -2, np.inf, -np.inf, 'nuc'])
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testNorm(self, shape, dtype, ord, axis, keepdims, rng):
        _skip_if_unsupported_type(dtype)
        if (ord in ('nuc', 2, -2)
                and (jtu.device_under_test() != "cpu" or
                     (isinstance(axis, tuple) and len(axis) == 2))):
            raise unittest.SkipTest("No adequate SVD implementation available")

        args_maker = lambda: [rng(shape, dtype)]
        onp_fn = partial(onp.linalg.norm,
                         ord=ord,
                         axis=axis,
                         keepdims=keepdims)
        np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
        # Older numpy versions promote to float64 unnecessarily.
        check_dtypes = numpy_version >= (1, 15)
        self._CheckAgainstNumpy(onp_fn,
                                np_fn,
                                args_maker,
                                check_dtypes=check_dtypes,
                                tol=1e-3)
        self._CompileAndCheck(np_fn, args_maker, check_dtypes=check_dtypes)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}_full_matrices={}_compute_uv={}".format(
                jtu.format_shape_dtype_string(b + (m, n), dtype),
                full_matrices, compute_uv),
            "b":
            b,
            "m":
            m,
            "n":
            n,
            "dtype":
            dtype,
            "full_matrices":
            full_matrices,
            "compute_uv":
            compute_uv,
            "rng":
            rng
        } for b in [(), (3, ), (2, 3)] for m in [2, 7, 29, 53]
                            for n in [2, 7, 29, 53]
                            for dtype in float_types + complex_types
                            for full_matrices in [False, True]
                            for compute_uv in [False, True]
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")
    def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, rng):
        _skip_if_unsupported_type(dtype)
        if b != () and jax.lib.version <= (0, 1, 28):
            raise unittest.SkipTest("Batched SVD requires jaxlib 0.1.29")
        args_maker = lambda: [rng(b + (m, n), dtype)]

        # Norm, adjusted for dimension and type.
        def norm(x):
            norm = onp.linalg.norm(x, axis=(-2, -1))
            return norm / (max(m, n) * onp.finfo(dtype).eps)

        a, = args_maker()
        out = np.linalg.svd(a,
                            full_matrices=full_matrices,
                            compute_uv=compute_uv)
        if compute_uv:
            # Check the reconstructed matrices
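            # out = (u, s, vh); s[..., None, :] * u rescales the columns of u by the
            # singular values, so (s * u) @ vh reproduces u @ diag(s) @ vh.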
            if full_matrices:
                k = min(m, n)
                if m < n:
                    self.assertTrue(
                        onp.all(
                            norm(a -
                                 onp.matmul(out[1][..., None, :] *
                                            out[0], out[2][..., :k, :])) < 50))
                else:
                    self.assertTrue(
                        onp.all(
                            norm(a - onp.matmul(
                                out[1][..., None, :] *
                                out[0][..., :, :k], out[2])) < 350))
            else:
                self.assertTrue(
                    onp.all(
                        norm(a - onp.matmul(out[1][..., None, :] *
                                            out[0], out[2])) < 300))

            # Check the unitary properties of the singular vector matrices.
            self.assertTrue(
                onp.all(
                    norm(
                        onp.eye(out[0].shape[-1]) -
                        onp.matmul(onp.conj(T(out[0])), out[0])) < 10))
            if m >= n:
                self.assertTrue(
                    onp.all(
                        norm(
                            onp.eye(out[2].shape[-1]) -
                            onp.matmul(onp.conj(T(out[2])), out[2])) < 10))
            else:
                self.assertTrue(
                    onp.all(
                        norm(
                            onp.eye(out[2].shape[-2]) -
                            onp.matmul(out[2], onp.conj(T(out[2])))) < 20))

        else:
            self.assertTrue(
                onp.allclose(onp.linalg.svd(a, compute_uv=False),
                             onp.asarray(out),
                             atol=1e-4,
                             rtol=1e-4))

        self._CompileAndCheck(partial(np.linalg.svd,
                                      full_matrices=full_matrices,
                                      compute_uv=compute_uv),
                              args_maker,
                              check_dtypes=True)
        if not full_matrices:
            svd = partial(np.linalg.svd, full_matrices=False)
            jtu.check_jvp(svd, partial(jvp, svd), (a, ), rtol=1e-2, atol=1e-1)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_fullmatrices={}".format(
                jtu.format_shape_dtype_string(shape, dtype), full_matrices),
            "shape":
            shape,
            "dtype":
            dtype,
            "full_matrices":
            full_matrices,
            "rng":
            rng
        } for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]
                            for dtype in float_types + complex_types
                            for full_matrices in [False, True]
                            for rng in [jtu.rand_default()]))
    def testQr(self, shape, dtype, full_matrices, rng):
        _skip_if_unsupported_type(dtype)
        if (onp.issubdtype(dtype, onp.complexfloating)
                and (jtu.device_under_test() == "tpu" or jax.lib.version <=
                     (0, 1, 27))):
            raise unittest.SkipTest("No complex QR implementation")
        m, n = shape[-2:]

        if full_matrices:
            mode, k = "complete", m
        else:
            mode, k = "reduced", min(m, n)

        a = rng(shape, dtype)
        lq, lr = np.linalg.qr(a, mode=mode)

        # onp.linalg.qr doesn't support batch dimensions. But it seems like an
        # inevitable extension so we support it in our version.
        nq = onp.zeros(shape[:-2] + (m, k), dtype)
        nr = onp.zeros(shape[:-2] + (k, n), dtype)
        for index in onp.ndindex(*shape[:-2]):
            nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode)

        max_rank = max(m, n)

        # Norm, adjusted for dimension and type.
        def norm(x):
            n = onp.linalg.norm(x, axis=(-2, -1))
            return n / (max_rank * onp.finfo(dtype).eps)

        def compare_orthogonal(q1, q2):
            # Q is unique up to sign, so normalize the sign first.
            sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
            phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
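            # Each column of q1 / q2 is (up to noise) a constant unit-modulus factor: a
            # sign for real dtypes, a phase for complex ones. Averaging the elementwise
            # ratios and normalizing to unit modulus estimates that factor per column.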
            q1 *= phases
            self.assertTrue(onp.all(norm(q1 - q2) < 30))

        # Check a ~= qr
        self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

        # Compare the first 'k' vectors of Q; the remainder form an arbitrary
        # orthonormal basis for the null space.
        compare_orthogonal(nq[..., :k], lq[..., :k])

        # Check that q is close to unitary.
        self.assertTrue(
            onp.all(norm(onp.eye(k) - onp.matmul(onp.conj(T(lq)), lq)) < 5))

        if not full_matrices and m >= n:
            jtu.check_jvp(np.linalg.qr,
                          partial(jvp, np.linalg.qr), (a, ),
                          atol=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(10, 4, 5), (5, 3, 3), (7, 6, 4)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testQrBatching(self, shape, dtype, rng):
        args = rng(shape, np.float32)
        qs, rs = vmap(jsp.linalg.qr)(args)
        self.assertTrue(
            onp.all(onp.linalg.norm(args - onp.matmul(qs, rs)) < 1e-3))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype)),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for lhs_shape, rhs_shape in [
            ((1, 1), (1, 1)),
            ((4, 4), (4, )),
            ((8, 8), (8, 4)),
            ((1, 2, 2), (3, 2)),
            ((2, 1, 3, 3), (2, 4, 3, 4)),
        ] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testSolve(self, lhs_shape, rhs_shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]

        self._CheckAgainstNumpy(onp.linalg.solve,
                                np.linalg.solve,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)]
                            for dtype in float_types
                            for rng in [jtu.rand_default()]))
    def testInv(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        if jtu.device_under_test() == "gpu" and shape == (200, 200):
            raise unittest.SkipTest("Test is flaky on GPU")

        def args_maker():
            invertible = False
            while not invertible:
                a = rng(shape, dtype)
                try:
                    onp.linalg.inv(a)
                    invertible = True
                except onp.linalg.LinAlgError:
                    pass
            return [a]

        self._CheckAgainstNumpy(onp.linalg.inv,
                                np.linalg.inv,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)

    # Regression test for incorrect type for eigenvalues of a complex matrix.
    @jtu.skip_on_devices("tpu"
                         )  # TODO(phawkins): No eigh implementation on TPU.
    def testIssue669(self):
        def test(x):
            val, vec = np.linalg.eigh(x)
            return np.real(np.sum(val))

        grad_test_jc = jit(grad(jit(test)))
        xc = onp.eye(3, dtype=onp.complex128)  # onp.complex is a deprecated alias for complex128
        self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)

    def testIssue1151(self):
        A = np.array(onp.random.randn(100, 3, 3), dtype=np.float32)
        b = np.array(onp.random.randn(100, 3), dtype=np.float32)
        x = np.linalg.solve(A, b)
        self.assertAllClose(vmap(np.dot)(A, x),
                            b,
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=True)
        jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A, b)
        jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A, b)
        jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A[0], b[0])
        jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A[0], b[0])
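
# A short, standalone sketch of the batched-solve check exercised by testIssue1151
# above. It also spells out the alias convention this example appears to use
# (assumed from context, not stated in the file): `onp` is the original NumPy and
# `np` is jax.numpy.
import numpy as onp
import jax.numpy as np
from jax import vmap

A = np.array(onp.random.randn(10, 3, 3), dtype=np.float32)
b = np.array(onp.random.randn(10, 3), dtype=np.float32)
x = np.linalg.solve(A, b)                # one 3x3 solve per batch entry
assert onp.allclose(vmap(np.dot)(A, x), b, atol=1e-3)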
Example No. 6
0
class HostCallbackTest(jtu.JaxTestCase):
    def setUp(self):
        testing_stream.reset()
        testing_stream.test_method_name = self._testMethodName
        self.old_flags = os.getenv("XLA_FLAGS", "")

    def tearDown(self) -> None:
        if os.getenv("XLA_FLAGS") != self.old_flags:
            os.environ["XLA_FLAGS"] = self.old_flags
            xla_bridge.get_backend.cache_clear()
        hcb.barrier_wait()

    def helper_set_devices(self, nr_devices):
        flags_str = os.getenv("XLA_FLAGS", "")
        os.environ["XLA_FLAGS"] = (
            flags_str +
            " --xla_force_host_platform_device_count={}".format(nr_devices))
        # Clear any cached backends so new CPU backend will pick up the env var.
        xla_bridge.get_backend.cache_clear()
        return api.devices()

    def helper_set_hlo_dump(self):
        flags_str = os.getenv("XLA_FLAGS", "")
        os.environ["XLA_FLAGS"] = f"{flags_str} --xla_dump_to=/tmp/xla_dump"
        # Clear any cached backends so new CPU backend will pick up the env var.
        xla_bridge.get_backend.cache_clear()

    def test_eval(self):
        # TODO: re-enable jaxpr golden tests when changing host_callback.
        #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(fun1)(5.)))

        self.assertAllClose((5. * 2.)**2, fun1(5.))
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
what: a * 2
10.00
what: y * 3
30.00""", testing_stream.output)
        testing_stream.reset()

    def test_with_tuple_results(self):
        def func2(x):
            x1, y1 = hcb.id_print((x * 2., x * 3.),
                                  output_stream=testing_stream)
            return x1 + y1

        #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(func2)(3.)))
        self.assertEqual(3. * (2. + 3.), func2(3.))
        hcb.barrier_wait()

        assertMultiLineStrippedEqual(self, """
[ 6.00
  9.00 ]""", testing_stream.output)
        testing_stream.reset()

    def test_with_dict_results(self):
        def func2(x):
            res = hcb.id_print(dict(a=x * 2., b=x * 3.),
                               output_stream=testing_stream)
            return res["a"] + res["b"]

        self.assertEqual(3. * (2. + 3.), func2(3.))
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(self, """
{ a=6.00
  b=9.00 }""", testing_stream.output)
        testing_stream.reset()

    def test_with_result(self):
        def func2(x):
            x1 = hcb.id_print((x * 2., x * 3.),
                              result=x * 4.,
                              output_stream=testing_stream)
            return x1

        self.assertEqual(3. * 4., func2(3.))
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(self, """
[ 6.00
  9.00 ]""", testing_stream.output)
        testing_stream.reset()

    def test_eval_tap_exception(self):
        # Simulate a tap error
        def tap_err(*args, **kwargs):
            raise NotImplementedError

        def func(x):
            x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
            x2 = hcb.id_tap(tap_err, x1 + 1, what="err")
            x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
            return x3

        with self.assertRaises(hcb.TapFunctionException):
            func(0)
            hcb.barrier_wait()

        # We should have received everything before the error
        assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
        testing_stream.reset()

    def test_jit_simple(self):
        jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
            2. * x, what="here", output_stream=testing_stream))
        self.assertAllClose(6. * 5., jit_fun1(5.))
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(self, """
what: here
10.00""", testing_stream.output)
        testing_stream.reset()

    def test_jit_constant(self):
        def func(x):
            return hcb.id_print(42, result=x, output_stream=testing_stream)

        #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(api.jit(func))(5)))

        self.assertAllClose(5, api.jit(func)(5))
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
        testing_stream.reset()

    def test_jit_sequence1(self):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            return hcb.id_print(x1 + 1,
                                where="2",
                                output_stream=testing_stream)

        logging.info("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
        logging.info("%s: %s", self._testMethodName,
                     api.xla_computation(func)(1).as_hlo_text())
        self.assertEqual(2, api.jit(func)(1))
        hcb.barrier_wait()

        assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2""", testing_stream.output)
        testing_stream.reset()

    def test_jit2(self):
        """A sequence of JIT."""
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
            return x2

        self.assertEqual(2, api.jit(func)(1))
        self.assertEqual(11, api.jit(func)(10))
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: 1
10
where: 2
11""", testing_stream.output)
        testing_stream.reset()

    def test_jit_nested(self):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)

            def func_nested(x):
                x2 = hcb.id_print(x + 1,
                                  where="nested",
                                  output_stream=testing_stream)
                return x2

            x3 = api.jit(func_nested)(x1)
            return hcb.id_print(x3 + 1,
                                where="3",
                                output_stream=testing_stream)

        self.assertEqual(3, api.jit(func)(1))
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: nested
2
where: 3
3""", testing_stream.output)
        testing_stream.reset()

    def test_jit_devices(self):
        """Running on multiple devices."""
        devices = api.local_devices()
        logging.info(f"{self._testMethodName}: has devices {devices}")

        def func(x, device_id):
            x1 = hcb.id_print(x,
                              dev=str(device_id),
                              output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1,
                              dev=str(device_id),
                              output_stream=testing_stream)
            return x2

        for d in devices:
            self.assertEqual(
                112,
                api.jit(func, device=d, static_argnums=1)(111, d.id))
        hcb.barrier_wait()
        logging.info(
            f"{self._testMethodName}: found output {testing_stream.output}")
        self.assertEqual(len(devices),
                         len(re.findall(r"111", testing_stream.output)))
        self.assertEqual(len(devices),
                         len(re.findall(r"112", testing_stream.output)))
        testing_stream.reset()

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit)
            for with_jit in [True, False]))
    def test_pytree(self, with_jit=False):
        def func(x, what=""):
            """Returns some pytrees depending on x"""
            if what == "pair_1_x":
                return (1, x)
            elif what == "pair_x_2x":
                return (x, 2 * x)
            elif what == "dict":
                return dict(a=2 * x, b=3 * x)
            else:
                assert False

        tap_count = 0

        def tap_func(a, what=""):
            nonlocal tap_count
            tap_count += 1
            self.assertEqual(func(5, what), a)

        transform = api.jit if with_jit else lambda f: f
        for what in ("pair_1_x", "pair_x_2x", "dict"):
            self.assertEqual(
                func(10, what),
                transform(lambda x: hcb.id_tap(tap_func,
                                               func(x, what),
                                               result=func(x * 2, what),
                                               what=what))(5))
        hcb.barrier_wait()  # Wait for receivers to be done
        self.assertEqual(3, tap_count)
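
    # As exercised by test_pytree above (behavior inferred from this usage, not
    # from the host_callback docs): id_tap returns its `result` argument when one
    # is given, otherwise the tapped value itself, so
    #     y = hcb.id_tap(tap_func, arg, result=res, what="tag")
    # evaluates to `res` while scheduling tap_func(arg, what="tag") on the host;
    # hcb.barrier_wait() blocks until all outstanding taps have been consumed.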

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_concurrent_{concurrent}",
                 concurrent=concurrent) for concurrent in [True, False]))
    def test_multiple_tap(self, concurrent=False):
        """Call id_tap multiple times, concurrently or in sequence. """
        if concurrent and jtu.device_under_test() == "gpu":
            # TODO(necula): it seems that on GPU if multiple host threads run
            # a jit computation, the multiple computations are interleaved on the
            # GPU. This can result in the outfeed trains being interleaved, which
            # will trigger an error. The solution is to fix on GPU the receiving
            # logic so that we can outfeed the train as one tuple, and receive it
            # one piece as a time. Then the trains should be atomic.
            # See also b/160692602.
            raise SkipTest("concurrent id_tap not supported on GPU")
        received = set()
        count = 5

        def pause_tap(idx, **kwargs):
            received.add(int(idx))
            logging.info(f"Starting do_tap {idx}. Sleeping 1sec ...")
            time.sleep(0.3)
            logging.info(f"Finish do_tap {idx}")

        def do_tap(idx):
            api.jit(lambda idx: hcb.id_tap(pause_tap, idx))(idx)

        if concurrent:
            threads = [
                threading.Thread(name=f"enqueue_tap_{idx}",
                                 target=do_tap,
                                 args=(idx, )) for idx in range(count)
            ]
            [t.start() for t in threads]
            [t.join() for t in threads]
        else:
            for idx in range(count):
                do_tap(idx)

        hcb.barrier_wait()
        self.assertEqual(received, set(range(count)))

    # TODO(necula): see comment for test_multiple_tap.
    @jtu.skip_on_devices("gpu")
    def test_multiple_barriers(self):
        """Call barrier_wait concurrently."""
        def pause_tap(*args, **kwargs):
            logging.info("pause_tap waiting")
            time.sleep(0.3)
            logging.info("pause_tap done")

        def long_run(x):
            return hcb.id_tap(pause_tap, x)

        api.jit(long_run)(5.)

        def try_barrier(idx):
            logging.info(f"Starting test barrier {idx}")
            hcb.barrier_wait()
            logging.info(f"Finished test barrier {idx}")

        threads = [
            threading.Thread(name=f"barrier_{idx}",
                             target=try_barrier,
                             args=(idx, )) for idx in range(3)
        ]
        [t.start() for t in threads]
        [t.join() for t in threads]

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit)
            for with_jit in [True, False]))
    def test_cond(self, with_jit=False):
        """A conditional"""
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            x4 = lax.cond(
                x % 2 == 0, lambda x: hcb.id_print(
                    x, where="cond_t", output_stream=testing_stream),
                lambda x: hcb.id_print(
                    -1, where="cond_f", result=x, output_stream=testing_stream
                ), x2 + 1)
            x5 = hcb.id_print(x4 + 1,
                              where="end",
                              output_stream=testing_stream)
            return x5

        transform = api.jit if with_jit else lambda f: f
        self.assertEqual(4, transform(func)(1))
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: cond_f
-1
where: end
4""", testing_stream.output)
        testing_stream.reset()

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit)
            for with_jit in [True, False]))
    def test_while_cond(self, with_jit=False):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            def body(x):
                x3 = hcb.id_print(x,
                                  where="w_b_1",
                                  output_stream=testing_stream)
                x4 = lax.cond(
                    x % 2 == 0, lambda x: hcb.id_print(
                        x, where="w_b_t", output_stream=testing_stream),
                    lambda x: hcb.id_print(-1,
                                           where="w_b_f",
                                           result=x,
                                           output_stream=testing_stream),
                    x3 + 1)
                return hcb.id_print(x4,
                                    where="w_b_2",
                                    output_stream=testing_stream)

            x10 = lax.while_loop(lambda x: x <= 3, body, x2)
            res = hcb.id_print(x10, where="end", output_stream=testing_stream)
            return res

        transform = api.jit if with_jit else lambda f: f
        self.assertEqual(4, transform(func)(1))
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: w_b_1
2
where: w_b_t
3
where: w_b_2
3
where: w_b_1
3
where: w_b_f
-1
where: w_b_2
4
where: end
4""", testing_stream.output)
        testing_stream.reset()

    def test_jit_while_pred_tap(self):
        """While with printing in the conditional."""
        def func(x):
            x1 = hcb.id_print(x, where="1")
            x10 = lax.while_loop(
                lambda x: hcb.id_print(
                    x < 3, where="w_p", output_stream=testing_stream),
                lambda x: hcb.id_print(
                    x + 1, where="w_b", output_stream=testing_stream), x1)
            res = hcb.id_print(x10, where="3", output_stream=testing_stream)
            return res

        self.assertEqual(3, api.jit(func)(1))
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
where: w_p
True
where: w_b
2
where: w_p
True
where: w_b
3
where: w_p
False
where: 3
3""", testing_stream.output)
        testing_stream.reset()

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit)
            for with_jit in [True, False]))
    def test_scan_cond(self, with_jit=False):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            def body(c, x):
                x3 = hcb.id_print(x, where="s_1", output_stream=testing_stream)
                x4 = lax.cond(
                    x % 2 == 0, lambda x: hcb.id_print(
                        x, where="s_t", output_stream=testing_stream),
                    lambda x: hcb.id_print(-1,
                                           where="s_f",
                                           result=x,
                                           output_stream=testing_stream),
                    x3 + 1)
                return (c,
                        hcb.id_print(x4,
                                     where="s_2",
                                     output_stream=testing_stream))

            _, x10 = lax.scan(body, x2, jnp.arange(3))
            res = hcb.id_print(x10, where="10", output_stream=testing_stream)
            return res

        if with_jit:
            func = api.jit(func)
        res = func(1)
        self.assertAllClose(jnp.array([1, 2, 3]), res)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: s_1
0
where: s_t
1
where: s_2
1
where: s_1
1
where: s_f
-1
where: s_2
2
where: s_1
2
where: s_t
3
where: s_2
3
where: 10
[1 2 3]""", testing_stream.output)
        testing_stream.reset()

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(
                testcase_name=f"_shape_{shape}_dtype_{dtype}_nr_args={nr_args}",
                shape=shape,
                dtype=dtype,
                nr_args=nr_args) for nr_args in [1, 2]
            for shape in [(), (2, ), (2, 3), (2, 3, 4)]
            for dtype in jtu.dtypes.all))
    def test_jit_types(self, nr_args=2, dtype=jnp.int16, shape=(2, )):
        if dtype in (jnp.complex64, jnp.complex128, jnp.bool_):
            raise SkipTest(f"id_print jit not implemented for {dtype}.")
        if jtu.device_under_test() == "tpu":
            if dtype in (jnp.int16, ):
                raise SkipTest(f"transfering {dtype} not supported on TPU")
        args = [jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)]
        if nr_args > 1:
            args = args * nr_args
        jit_fun1 = api.jit(lambda xs: hcb.id_print(
            xs,
            a_new_test="************",
            testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}"))

        res = jit_fun1(args)
        self.assertAllClose(args, res)

    def test_jit_large(self):
        arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1))
        api.jit(hcb.id_print)(arg)

    def test_jit_several_together(self):
        arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
        api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(
            arg, jnp.ones(100, dtype=jnp.int32))

    def test_jit_interleaving(self):
        # Several jit's without data dependencies; they may interfere
        count = 0  # Count tap invocations
        nr_arrays = 5

        def tap_func(arg, **_):
            nonlocal count
            assert len(arg) == nr_arrays
            count += 1

        # This is the function that we'll run multiple times
        def func(x, count):
            for i in range(count):
                x = hcb.id_tap(tap_func, [x + i for i in range(nr_arrays)],
                               i=i)[-1]
            return x

        x = jnp.array(1, dtype=np.int32)
        res = 0
        for _ in range(10):
            # No dependencies between the jit invocations
            res += api.jit(lambda x: func(x, 10))(x)
        hcb.barrier_wait()
        self.assertEqual(100, count)

    def test_jit_tap_exception(self):
        # Simulate a tap error
        def tap_err(*args, **kwargs):
            raise NotImplementedError

        def func(x):
            x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
            x2 = hcb.id_tap(tap_err, x1 + 1, what="err")
            x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
            return x3

        res = api.jit(func)(0)  # No error yet
        with self.assertRaises(hcb.TapFunctionException):
            hcb.barrier_wait()

        # Even though the receiver thread raised, the main thread should still
        # return 3.
        self.assertEqual(3, res)
        # We should have received all others
        assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
        testing_stream.reset()

    def test_jit_nested_cond_no_print(self):
        """A nested conditional, without any prints"""
        raise SkipTest("skip this")

        @api.jit
        def cfun(x):
            return lax.cond(
                lax.lt(x, 2), lambda x: x,
                lambda x: lax.cond(x < 5, 3, lambda x: x, 4, lambda y: y), x)

        print(self._testMethodName, api.xla_computation(cfun)(1).as_hlo_text())
        cfun(1)

    def test_while(self):
        """Executing while, even without JIT uses compiled code"""
        y = jnp.ones(5)  # captured const

        def func(x):
            return lax.while_loop(
                lambda c: c[1] < 5, lambda c:
                (y, hcb.id_print(c[1], output_stream=testing_stream) + 1),
                (x, 1))

        func(y)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(self, """
1
2
3
4""", testing_stream.output)
        testing_stream.reset()

    def test_jvp(self):
        jvp_fun1 = lambda x, xt: api.jvp(fun1, (x, ), (xt, ))
        #assertMultiLineStrippedEqual(self, "")
        res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
        self.assertAllClose(100., res_primals, check_dtypes=False)
        self.assertAllClose(4., res_tangents, check_dtypes=False)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
what: a * 2
10.00
transforms: ({'name': 'jvp'},) what: a * 2
0.20
what: y * 3
30.00
transforms: ({'name': 'jvp'},) what: y * 3
0.60""", testing_stream.output)
        testing_stream.reset()

    def test_grad_primal_unused(self):
        # The output of id_print is not needed for backwards pass
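        # (d/dx of 2. * (x * 3.) is the constant 6., which is why the jaxpr printed
        #  below reduces to returning 6.00.)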
        def func(x):
            return 2. * hcb.id_print(
                x * 3., what="x * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        jaxpr = str(api.make_jaxpr(grad_func)(5.))
        # Just making the Jaxpr invokes the id_print once
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let
  in (6.00,) }""", jaxpr)
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
        testing_stream.reset()

        res_grad = grad_func(jnp.float32(5.))
        hcb.barrier_wait()

        self.assertAllClose(6., res_grad, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: x * 3
15.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
        testing_stream.reset()

    def test_grad_simple(self):
        def func(x):
            y = hcb.id_print(x * 2.,
                             what="x * 2",
                             output_stream=testing_stream)
            return x * hcb.id_print(
                y * 3., what="y * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(grad_func)(5.)))

        res_grad = grad_func(jnp.float32(5.))
        self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
what: x * 2
10.00
what: y * 3
30.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3
5.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00""", testing_stream.output)
        testing_stream.reset()

    def test_grad_double(self):
        def func(x):
            y = hcb.id_print(x * 2.,
                             what="x * 2",
                             output_stream=testing_stream)
            return x * (y * 3.)

        grad_func = api.grad(api.grad(func))
        # Just making the Jaxpr invokes the id_print twice
        _ = api.make_jaxpr(grad_func)(5.)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00""", testing_stream.output)
        testing_stream.reset()
        res_grad = grad_func(jnp.float32(5.))

        self.assertAllClose(12., res_grad, check_dtypes=False)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
what: x * 2
10.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00""", testing_stream.output)
        testing_stream.reset()

    def test_vmap(self):
        vmap_fun1 = api.vmap(fun1)
        vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
        #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_fun1)(vargs)))
        vmap_fun1(vargs)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2
[ 8.00 10.00]
transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3
[24.00 30.00]""", testing_stream.output)
        testing_stream.reset()

    def test_vmap_not_batched(self):
        x = 3.

        def func(y):
            # x is not mapped, y is mapped
            _, y = hcb.id_print((x, y), output_stream=testing_stream)
            return x + y

        vmap_func = api.vmap(func)
        vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
        #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_func)(vargs)))
        _ = vmap_func(vargs)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
[ 3.00
  [4.00 5.00] ]""", testing_stream.output)
        testing_stream.reset()

    def test_double_vmap(self):
        # A 2D tensor with x[i, j] = i + j using 2 vmap
        def sum(x, y):
            return hcb.id_print(x + y, output_stream=testing_stream)

        def sum_rows(xv, y):
            return api.vmap(sum, in_axes=(0, None))(xv, y)

        def sum_all(xv, yv):
            return api.vmap(sum_rows, in_axes=(None, 0))(xv, yv)

        xv = jnp.arange(5, dtype=np.int32)
        yv = jnp.arange(3, dtype=np.int32)
        #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(sum_all)(xv, yv)))
        _ = sum_all(xv, yv)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dims': (0,)})
[[0 1 2 3 4]
 [1 2 3 4 5]
 [2 3 4 5 6]]""", testing_stream.output)
        testing_stream.reset()
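
# The nested-vmap structure above is just an outer sum; a minimal sketch of the
# same batching pattern without the id_print tap:
import jax
import jax.numpy as jnp

add = lambda x, y: x + y
sum_rows = jax.vmap(add, in_axes=(0, None))      # map over x, broadcast y
sum_all = jax.vmap(sum_rows, in_axes=(None, 0))  # then map over y

xv = jnp.arange(5)
yv = jnp.arange(3)
print(sum_all(xv, yv))  # shape (3, 5), entry [j, i] == i + j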

    def test_vmap_while(self):
        """Vmap of while."""
        def func(x):
            # like max(x, 2)
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = lax.while_loop(
                lambda x: x < 2, lambda x: hcb.id_print(
                    x + 1, where="w_b", output_stream=testing_stream), x1)
            res = hcb.id_print(x2, where="3", output_stream=testing_stream)
            return res

        inputs = np.arange(5, dtype=np.int32)
        self.assertAllClose(np.array([2, 2, 2, 3, 4]),
                            api.jit(api.vmap(func))(inputs),
                            check_dtypes=False)
        hcb.barrier_wait()
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1
[0 1 2 3 4]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
[1 2 3 4 5]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
[2 3 3 4 5]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3
[2 2 2 3 4]""", testing_stream.output)
        testing_stream.reset()

    def test_vmap_while_tap_cond(self):
        """Vmap of while, with a tap in the conditional."""
        def func(x):
            # like max(x, 2)
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = lax.while_loop(
                lambda x: hcb.id_print(
                    x < 2, where="w_c", output_stream=testing_stream),
                lambda x: hcb.id_print(
                    x + 1, where="w_b", output_stream=testing_stream), x1)
            res = hcb.id_print(x2, where="3", output_stream=testing_stream)
            return res

        inputs = np.arange(5, dtype=np.int32)
        res = api.jit(api.vmap(func))(inputs)
        hcb.barrier_wait()
        self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1
[0 1 2 3 4]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
[ True  True False False False]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
[1 2 3 4 5]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
[ True False False False False]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
[2 3 3 4 5]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
[False False False False False]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3
[2 2 2 3 4]""", testing_stream.output)
        testing_stream.reset()
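
# Tap-free sketch of the batched while_loop above: under vmap the loop keeps
# iterating until every lane's predicate is False, which is why the taps show
# the full batched carry at each step.
import jax
import jax.numpy as jnp
from jax import lax

def at_least_two(x):  # like max(x, 2), via a while loop
    return lax.while_loop(lambda v: v < 2, lambda v: v + 1, x)

print(jax.vmap(at_least_two)(jnp.arange(5)))  # [2 2 2 3 4]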

    def test_pmap(self):
        vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32)

        pmap_fun1 = api.pmap(fun1, axis_name="i")
        res = pmap_fun1(vargs)
        hcb.barrier_wait()
        expected_res = jnp.stack(
            [fun1_equiv(2. + a) for a in range(api.local_device_count())])
        self.assertAllClose(expected_res, res, check_dtypes=False)
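
# Minimal pmap sketch without host_callback: map a function over the local
# devices; with a single device this is simply a batch of size one.
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = 2. + jnp.arange(n, dtype=jnp.float32)
print(jax.pmap(lambda v: (v * 2.) ** 2)(x))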

    def test_mask(self):
        # TODO(necula)
        raise SkipTest("masking has regressed")

        @functools.partial(api.mask, in_shapes=['n'], out_shape='')
        def padded_sum(x):
            return jnp.sum(
                hcb.id_print(x, what="x", output_stream=testing_stream))

        args = [jnp.arange(4)], dict(n=np.int64(2))
        assertMultiLineStrippedEqual(
            self, """
{ lambda c f ; a b.
  let d = lt c b
      e = id_tap[ func=_print
                  logical_shapes=[(Traced<ShapedArray(int32[]):JaxprTrace(level=0/0)>,)]
                  transforms=('mask',)
                  what=x ] a
      g = select d e f
      h = reduce_sum[ axes=(0,) ] g
  in (h,) }""", str(api.make_jaxpr(padded_sum)(*args)))

        _ = padded_sum(*args)
        self.assertMultiLineStrippedEqual(
            """
logical_shapes: [(2,)] transforms: ('mask',) what: x
[0 1 2 3]
   """, testing_stream.output)
        testing_stream.reset()

    def test_outfeed_receiver(self):
        """Test the deprecated outfeed_receiver"""
        with hcb.outfeed_receiver():
            self.assertAllClose((5. * 2.)**2, fun1(5.), check_dtypes=True)
        assertMultiLineStrippedEqual(
            self, """
what: a * 2
10.00
what: y * 3
30.00""", testing_stream.output)
        testing_stream.reset()

    def test_callback_delay(self):
        hcb.callback_extra = lambda dev: time.sleep(1)

        def func(x):
            for i in range(5):
                x = hcb.id_print(x * i, what="x times i")
            return x

        api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))

    def test_callback_delay_barrier(self):
        hcb.callback_extra = lambda dev: time.sleep(2)

        def func(x):
            for i in range(1, 4):
                x = hcb.id_print(x * i,
                                 what="x times i",
                                 output_stream=testing_stream)
            return x

        api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
        # Wait for the results
        hcb.barrier_wait()
        expected = """
what: x times i
[[0. 1. 2.]
 [3. 4. 5.]]
what: x times i
[[ 0.  2.  4.]
 [ 6.  8. 10.]]
what: x times i
[[ 0.  6. 12.]
 [18. 24. 30.]]"""
        self.assertMultiLineStrippedEqual(expected, testing_stream.output)
        testing_stream.reset()
        # Call again
        api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
        hcb.barrier_wait()
        self.assertMultiLineStrippedEqual(expected, testing_stream.output)

    def test_error_bad_consumer_id(self):
        """Try to use reserved consumer ID 0.

        Check that we get the proper error from the runtime."""
        comp = xla_bridge.make_computation_builder(self._testMethodName)
        token = hcb.xops.CreateToken(comp)
        hcb._initialize_outfeed_receiver()  # Needed if this is the sole test
        with self.assertRaisesRegex(
                RuntimeError, "Consumer ID cannot be a reserved value: 0"):
            hcb._outfeed_receiver.receiver.add_outfeed(comp, token, 0, [
                xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))
            ])

    def test_error_different_shapes(self):
        """Try to register different shapes for the same consumer ID."""
        comp = xla_bridge.make_computation_builder(self._testMethodName)
        token = hcb.xops.CreateToken(comp)
        hcb._initialize_outfeed_receiver()  # Needed if this is the sole test
        hcb._outfeed_receiver.receiver.add_outfeed(
            comp, token, 123,
            [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))])
        with self.assertRaisesRegex(
                RuntimeError,
                ".*does not match previous shape element_type.*"):
            hcb._outfeed_receiver.receiver.add_outfeed(
                comp, token, 123,
                [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.int32))])
        with self.assertRaisesRegex(
                RuntimeError,
                ".*does not match previous shape element_type.*"):
            hcb._outfeed_receiver.receiver.add_outfeed(
                comp, token, 123,
                [xla_bridge.constant(comp, np.zeros((2, ), dtype=np.float32))])
Example #7
class LaxRandomTest(jtu.JaxTestCase):
    def _CheckCollisions(self, samples, nbits):
        fail_prob = 0.01  # conservative bound on statistical fail prob by Chebyshev
        nitems = len(samples)
        nbins = 2**nbits
        nexpected = nbins * (1 - ((nbins - 1) / nbins)**nitems)
        ncollisions = len(np.unique(samples))
        sq_percent_deviation = ((ncollisions - nexpected) / nexpected)**2
        self.assertLess(sq_percent_deviation,
                        1 / np.sqrt(nexpected * fail_prob))

    def _CheckKolmogorovSmirnovCDF(self, samples, cdf):
        fail_prob = 0.01  # conservative bound on statistical fail prob by Kolmo CDF
        self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob)

    def _CheckChiSquared(self, samples, pmf):
        alpha = 0.01  # significance level, threshold for p-value
        values, actual_freq = np.unique(samples, return_counts=True)
        expected_freq = pmf(values) * samples.size
        # per scipy: "A typical rule is that all of the observed and expected
        # frequencies should be at least 5."
        valid = (actual_freq > 5) & (expected_freq > 5)
        self.assertGreater(
            valid.sum(),
            1,
            msg='not enough valid frequencies for chi-squared test')
        _, p_value = scipy.stats.chisquare(actual_freq[valid],
                                           expected_freq[valid])
        self.assertGreater(p_value,
                           alpha,
                           msg=f'Failed chi-squared test with p={p_value}.\n'
                           'Expected vs. actual frequencies:\n'
                           f'{expected_freq[valid]}\n{actual_freq[valid]}')
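
# Standalone sketch of the chi-squared goodness-of-fit check performed by
# _CheckChiSquared above.  A Bernoulli(0.3) sample is used so that the observed
# and expected frequencies sum to the same total.
import numpy as np
import scipy.stats

samples = np.random.RandomState(0).binomial(1, 0.3, size=10000)
values, observed = np.unique(samples, return_counts=True)
expected = scipy.stats.bernoulli(0.3).pmf(values) * samples.size
_, p_value = scipy.stats.chisquare(observed, expected)
print(p_value > 0.01)  # statistical check at the 1% significance level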

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in jtu.dtypes.floating))
    def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype):
        bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64
        numpy_bits = np.array(1., dtype).view(bits_dtype)
        xla_bits = api.jit(lambda: lax.bitcast_convert_type(
            np.array(1., dtype), bits_dtype))()
        self.assertEqual(numpy_bits, xla_bits)

    def testThreefry2x32(self):
        # We test the hash by comparing to known values provided in the test code of
        # the original reference implementation of Threefry. For the values, see
        # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
        def result_to_hex(result):
            return tuple([hex(x.copy()).rstrip("L") for x in result])

        expected = ("0x6b200159", "0x99ba4efe")
        result = random.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))

        self.assertEqual(expected, result_to_hex(result))

        expected = ("0x1cb996fc", "0xbb002be7")
        result = random.threefry_2x32(np.uint32([-1, -1]), np.uint32([-1, -1]))
        self.assertEqual(expected, result_to_hex(result))

        expected = ("0xc4923a9c", "0x483df7a0")
        result = random.threefry_2x32(np.uint32([0x13198a2e, 0x03707344]),
                                      np.uint32([0x243f6a88, 0x85a308d3]))
        self.assertEqual(expected, result_to_hex(result))
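
# Known-answer sketch for the Threefry-2x32 hash that underlies JAX's PRNG,
# assuming jax.random.threefry_2x32 is still exposed in the installed version:
import numpy as np
from jax import random

out = random.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0]))
print([hex(int(v)) for v in out])  # expected ['0x6b200159', '0x99ba4efe']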

    def testThreefry2x32Large(self):
        n = 10000000
        result = random.threefry_2x32(
            (np.uint32(0x13198a2e), np.uint32(0x03707344)),
            jnp.concatenate([
                jnp.full((n, ), 0x243f6a88, jnp.uint32),
                jnp.full((n, ), 0x85a308d3, jnp.uint32)
            ]))
        np.testing.assert_equal(result[:n],
                                np.full((n, ), 0xc4923a9c, dtype=np.uint32))
        np.testing.assert_equal(result[n:],
                                np.full((n, ), 0x483df7a0, dtype=np.uint32))

    def testThreefry2x32Empty(self):
        # Regression test for an op-by-op crash for empty arrays in CUDA mode.
        with api.disable_jit():
            result = random.threefry_2x32(
                (np.uint32(0x13198a2e), np.uint32(0x03707344)),
                jnp.ones((
                    10,
                    0,
                ), jnp.uint32))
        np.testing.assert_equal(result, np.zeros((
            10,
            0,
        ), dtype=np.uint32))

    def testRngRandomBitsViewProperty(self):
        # TODO: add 64-bit if it ever supports this property.
        # TODO: will this property hold across endian-ness?
        N = 10
        key = random.PRNGKey(1701)
        nbits = [8, 16, 32]
        rand_bits = [
            jax._src.random._random_bits(key, n, (N * 64 // n, ))
            for n in nbits
        ]
        rand_bits_32 = np.array(
            [np.array(r).view(np.uint32) for r in rand_bits])
        assert np.all(rand_bits_32 == rand_bits_32[0])

    def testRngRandomBits(self):
        # Test specific outputs to ensure consistent random values between JAX versions.
        key = random.PRNGKey(1701)

        bits8 = jax._src.random._random_bits(key, 8, (3, ))
        expected8 = np.array([216, 115, 43], dtype=np.uint8)
        self.assertArraysEqual(bits8, expected8)

        bits16 = jax._src.random._random_bits(key, 16, (3, ))
        expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
        self.assertArraysEqual(bits16, expected16)

        bits32 = jax._src.random._random_bits(key, 32, (3, ))
        expected32 = np.array([56197195, 4200222568, 961309823],
                              dtype=np.uint32)
        self.assertArraysEqual(bits32, expected32)

        with jtu.ignore_warning(category=UserWarning,
                                message="Explicitly requested dtype.*"):
            bits64 = jax._src.random._random_bits(key, 64, (3, ))
        if FLAGS.jax_enable_x64:
            expected64 = np.array([
                3982329540505020460, 16822122385914693683, 7882654074788531506
            ],
                                  dtype=np.uint64)
        else:
            expected64 = np.array([676898860, 3164047411, 4010691890],
                                  dtype=np.uint32)
        self.assertArraysEqual(bits64, expected64)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in float_dtypes))
    def testRngUniform(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.uniform(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckCollisions(samples, jnp.finfo(dtype).nmant)
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf)
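
# Sketch of the Kolmogorov-Smirnov check used above, applied directly to
# jax.random.uniform samples:
import numpy as np
import scipy.stats
from jax import random

samples = np.asarray(random.uniform(random.PRNGKey(0), (10000,)))
pvalue = scipy.stats.kstest(samples, scipy.stats.uniform().cdf).pvalue
print(pvalue > 0.01)  # statistical check at the 1% significance level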

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in int_dtypes + uint_dtypes))
    def testRngRandint(self, dtype):
        lo = 5
        hi = 10

        key = random.PRNGKey(0)
        rand = lambda key: random.randint(key, (10000, ), lo, hi, dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self.assertTrue(np.all(lo <= samples))
            self.assertTrue(np.all(samples < hi))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in float_dtypes))
    def testNormal(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.normal(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in float_dtypes))
    def testTruncatedNormal(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.truncated_normal(key, -0.3, 0.3,
                                                   (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        min_val = np.min(uncompiled_samples)
        max_val = np.max(uncompiled_samples)
        self.assertTrue(min_val > -0.3)
        self.assertTrue(max_val < 0.3)
        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(
                samples,
                scipy.stats.truncnorm(-0.3, 0.3).cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
    def testShuffle(self, dtype):
        key = random.PRNGKey(0)
        x = np.arange(100).astype(dtype)
        rand = lambda key: random.shuffle(key, x)
        crand = api.jit(rand)

        with self.assertWarns(FutureWarning):
            perm1 = rand(key)
        with self.assertWarns(FutureWarning):
            perm2 = crand(key)

        self.assertAllClose(perm1, perm2)
        self.assertFalse(np.all(perm1 == x))  # seems unlikely!
        self.assertAllClose(np.sort(perm1), x, check_dtypes=False)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name":
             "_{}_shape={}_replace={}_weighted={}_array_input={}".format(
                 np.dtype(dtype).name, shape, replace, weighted, array_input),
             "dtype": dtype, "shape": shape, "replace": replace,
             "weighted": weighted, "array_input": array_input}
            for dtype in jtu.dtypes.floating + jtu.dtypes.integer
            for shape in [(), (5, ), (4, 5)]
            for replace in [True, False]
            for weighted in [True, False]
            for array_input in [False, 'jnp', 'np']))
    def testChoice(self, dtype, shape, replace, weighted, array_input):
        N = 100
        key = random.PRNGKey(0)
        x = (N if not array_input else jnp.arange(N, dtype=dtype)
             if array_input == 'jnp' else np.arange(N, dtype=dtype))
        p = None if not weighted else jnp.arange(N)
        rand = lambda key: random.choice(key, x, shape, p=p, replace=replace)
        crand = api.jit(rand)

        sample1 = rand(key)
        sample2 = crand(key)

        self.assertEqual(shape, sample1.shape)
        if array_input == 'jnp':
            self.assertEqual(x.dtype, sample1.dtype)
        if not replace:
            assert len(np.unique(sample1)) == len(np.ravel(sample1))
        self.assertAllClose(sample1, sample2)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name":
             "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
             "dtype": dtype, "shape": shape}
            for dtype in jtu.dtypes.floating + jtu.dtypes.integer
            for shape in [100, (10, 10), (10, 5, 2)]))
    def testPermutationArray(self, dtype, shape):
        key = random.PRNGKey(0)
        x = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)
        rand = lambda key: random.permutation(key, x)
        crand = api.jit(rand)

        perm1 = rand(key)
        perm2 = crand(key)

        self.assertAllClose(perm1, perm2)
        self.assertFalse(np.all(perm1 == x))  # seems unlikely!
        self.assertAllClose(np.sort(perm1.ravel()),
                            x.ravel(),
                            check_dtypes=False)
        self.assertArraysAllClose(
            x,
            jnp.arange(np.prod(shape)).reshape(shape).astype(dtype))

    def testPermutationInteger(self):
        key = random.PRNGKey(0)
        x = 100
        rand = lambda key: random.permutation(key, x)
        crand = api.jit(rand)

        perm1 = rand(key)
        perm2 = crand(key)

        self.assertAllClose(perm1, perm2)
        self.assertEqual(perm1.dtype, perm2.dtype)
        self.assertFalse(np.all(perm1 == np.arange(100)))  # seems unlikely!
        self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False)

    def testPermutationErrors(self):
        key = random.PRNGKey(0)
        with self.assertRaises(TypeError):
            random.permutation(key, 10.)
        with self.assertRaises(core.ConcretizationTypeError):
            api.jit(random.permutation)(key, 10)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_p={}_dtype={}".format(
                    p,
                    np.dtype(dtype).name),
                "p": p,
                "dtype": dtype
            } for p in [0.1, 0.5, 0.9] for dtype in jtu.dtypes.floating))
    def testBernoulli(self, p, dtype):
        key = random.PRNGKey(0)
        p = np.array(p, dtype=dtype)
        rand = lambda key, p: random.bernoulli(key, p, (10000, ))
        crand = api.jit(rand)

        uncompiled_samples = rand(key, p)
        compiled_samples = crand(key, p)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name":
             "_p={}_{}_{}".format(p, np.dtype(dtype).name, sample_shape),
             "p": p, "axis": axis, "dtype": dtype, "sample_shape": sample_shape}
            for (p, axis) in [
                ([.25] * 4, -1),
                ([.1, .2, .3, .4], -1),
                ([[.5, .5], [.1, .9]], 1),
                ([[.5, .1], [.5, .9]], 0),
            ]
            for sample_shape in [(10000, ), (5000, 2)]
            for dtype in jtu.dtypes.floating))
    def testCategorical(self, p, axis, dtype, sample_shape):
        key = random.PRNGKey(0)
        p = np.array(p, dtype=dtype)
        logits = np.log(p) - 42  # test unnormalized
        out_shape = tuple(np.delete(logits.shape, axis))
        shape = sample_shape + out_shape
        rand = partial(random.categorical, shape=shape, axis=axis)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, logits)
        compiled_samples = crand(key, logits)

        if axis < 0:
            axis += len(logits.shape)

        for samples in [uncompiled_samples, compiled_samples]:
            assert samples.shape == shape
            samples = jnp.reshape(samples, (10000, ) + out_shape)
            if len(p.shape[:-1]) > 0:
                ps = np.transpose(p, (1, 0)) if axis == 0 else p
                for cat_samples, cat_p in zip(samples.transpose(), ps):
                    self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
            else:
                self._CheckChiSquared(samples, pmf=lambda x: p[x])
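
# random.categorical samples from softmax(logits), so shifting the logits by a
# constant (as the "- 42" above does) leaves the distribution unchanged.  A
# minimal sketch:
import jax.numpy as jnp
from jax import random

probs = jnp.array([0.1, 0.2, 0.3, 0.4])
logits = jnp.log(probs) - 42.0  # unnormalized log-probabilities
samples = random.categorical(random.PRNGKey(0), logits, shape=(10000,))
print(jnp.bincount(samples, length=4) / 10000)  # roughly [0.1 0.2 0.3 0.4]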

    def testBernoulliShape(self):
        key = random.PRNGKey(0)
        x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2))
        assert x.shape == (3, 2)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name":
             "_a={}_b={}_dtype={}".format(a, b, np.dtype(dtype).name),
             "a": a, "b": b, "dtype": dtype}
            for a in [0.2, 5.] for b in [0.2, 5.]
            for dtype in [np.float64]))  # NOTE: KS test fails with float32
    def testBeta(self, a, b, dtype):
        if not FLAGS.jax_enable_x64:
            raise SkipTest("skip test except on X64")
        key = random.PRNGKey(0)
        rand = lambda key, a, b: random.beta(key, a, b, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, a, b)
        compiled_samples = crand(key, a, b)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples,
                                            scipy.stats.beta(a, b).cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in float_dtypes))
    def testCauchy(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.cauchy(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name":
             "_alpha={}_dtype={}".format(alpha, np.dtype(dtype).name),
             "alpha": alpha, "dtype": dtype}
            for alpha in [np.array([0.2, 1., 5.])]
            for dtype in jtu.dtypes.floating))
    @jtu.skip_on_devices("tpu")  # TODO(mattjj): slow compilation times
    def testDirichlet(self, alpha, dtype):
        key = random.PRNGKey(0)
        rand = lambda key, alpha: random.dirichlet(key, alpha,
                                                   (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, alpha)
        compiled_samples = crand(key, alpha)

        for samples in [uncompiled_samples, compiled_samples]:
            self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype))
            alpha_sum = sum(alpha)
            for i, a in enumerate(alpha):
                self._CheckKolmogorovSmirnovCDF(
                    samples[..., i],
                    scipy.stats.beta(a, alpha_sum - a).cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in float_dtypes))
    def testExponential(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.exponential(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_a={}_dtype={}".format(
                    a,
                    np.dtype(dtype).name),
                "a": a,
                "dtype": dtype
            } for a in [0.1, 1., 10.] for dtype in jtu.dtypes.floating))
    def testGamma(self, a, dtype):
        key = random.PRNGKey(0)
        rand = lambda key, a: random.gamma(key, a, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, a)
        compiled_samples = crand(key, a)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf)

    def testGammaShape(self):
        key = random.PRNGKey(0)
        x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2))
        assert x.shape == (3, 2)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_a={}".format(alpha),
            "alpha": alpha
        } for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
    def testGammaGrad(self, alpha):
        rng = random.PRNGKey(0)
        alphas = np.full((100, ), alpha)
        z = random.gamma(rng, alphas)
        actual_grad = api.grad(lambda x: random.gamma(rng, x).sum())(alphas)

        eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
        cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps) -
                   scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
        pdf = scipy.stats.gamma.pdf(z, alpha)
        expected_grad = -cdf_dot / pdf

        self.assertAllClose(
            actual_grad,
            expected_grad,
            check_dtypes=True,
            rtol=2e-2 if jtu.device_under_test() == "tpu" else 7e-4)
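
# The check above compares JAX's gamma-sample gradient against the implicit
# differentiation identity d z / d alpha = -(d CDF / d alpha) / pdf(z),
# evaluated at the drawn sample.  A single-scalar sketch of the same
# comparison (eps and rtol are chosen loosely here):
import numpy as np
import scipy.stats
import jax
from jax import random

alpha = 2.0
key = random.PRNGKey(0)
z = float(random.gamma(key, alpha))
jax_grad = float(jax.grad(lambda a: random.gamma(key, a))(alpha))

eps = 1e-4 * alpha
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps) -
           scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
fd_grad = -cdf_dot / scipy.stats.gamma.pdf(z, alpha)
print(np.isclose(jax_grad, fd_grad, rtol=1e-2))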

    def testGammaGradType(self):
        # Regression test for https://github.com/google/jax/issues/2130
        key = random.PRNGKey(0)
        a = jnp.array(1., dtype=jnp.float32)
        b = jnp.array(3., dtype=jnp.float32)
        f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y
        # Should not crash with a type error.
        api.vjp(f, a, b)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name":
             "_lam={}_dtype={}".format(lam, np.dtype(dtype).name),
             "lam": lam, "dtype": np.dtype(dtype)}
            for lam in [0.5, 3, 9, 11, 50, 500]
            for dtype in [np.int16, np.int32, np.int64]))
    def testPoisson(self, lam, dtype):
        key = random.PRNGKey(0)
        rand = lambda key, lam: random.poisson(key, lam, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, lam)
        compiled_samples = crand(key, lam)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
            # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
            # based on the central limit theorem).
            self.assertAllClose(samples.mean(),
                                lam,
                                rtol=0.01,
                                check_dtypes=False)
            self.assertAllClose(samples.var(),
                                lam,
                                rtol=0.03,
                                check_dtypes=False)

    def testPoissonBatched(self):
        key = random.PRNGKey(0)
        lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)])
        samples = random.poisson(key, lam, shape=(20000, ))
        self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf)
        self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf)

    def testPoissonShape(self):
        key = random.PRNGKey(0)
        x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
        assert x.shape == (3, 2)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in jtu.dtypes.floating))
    def testGumbel(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.gumbel(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples,
                                            scipy.stats.gumbel_r().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in float_dtypes))
    def testLaplace(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.laplace(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dtype={}".format(np.dtype(dtype).name),
                "dtype": dtype
            } for dtype in float_dtypes))
    def testLogistic(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.logistic(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples,
                                            scipy.stats.logistic().cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_b={}_dtype={}".format(
                    b,
                    np.dtype(dtype).name),
                "b": b,
                "dtype": dtype
            } for b in [0.1, 1., 10.] for dtype in jtu.dtypes.floating))
    def testPareto(self, b, dtype):
        key = random.PRNGKey(0)
        rand = lambda key, b: random.pareto(key, b, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, b)
        compiled_samples = crand(key, b)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf)

    def testParetoShape(self):
        key = random.PRNGKey(0)
        x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2))
        assert x.shape == (3, 2)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_df={}_dtype={}".format(
                    df,
                    np.dtype(dtype).name),
                "df": df,
                "dtype": dtype
            } for df in [0.1, 1., 10.] for dtype in jtu.dtypes.floating))
    @jtu.skip_on_devices("cpu",
                         "tpu")  # TODO(phawkins): slow compilation times
    def testT(self, df, dtype):
        key = random.PRNGKey(0)
        rand = lambda key, df: random.t(key, df, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key, df)
        compiled_samples = crand(key, df)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_dim={}_dtype={}".format(
                    dim, np.dtype(dtype)),
                "dim": dim,
                "dtype": dtype
            } for dim in [1, 3, 5] for dtype in float_dtypes))
    def testMultivariateNormal(self, dim, dtype):
        r = np.random.RandomState(dim)
        mean = r.randn(dim)
        cov_factor = r.randn(dim, dim)
        cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)

        key = random.PRNGKey(0)
        rand = partial(random.multivariate_normal,
                       mean=mean,
                       cov=cov,
                       shape=(10000, ))
        crand = api.jit(rand)

        uncompiled_samples = np.asarray(rand(key), np.float64)
        compiled_samples = np.asarray(crand(key), np.float64)

        inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov),
                                               lower=True)[0]
        for samples in [uncompiled_samples, compiled_samples]:
            centered = samples - mean
            whitened = np.einsum('nj,ij->ni', centered, inv_scale)

            # This is a quick-and-dirty multivariate normality check that tests that a
            # uniform mixture of the marginals along the covariance matrix's
            # eigenvectors follow a standard normal distribution.
            self._CheckKolmogorovSmirnovCDF(whitened.ravel(),
                                            scipy.stats.norm().cdf)
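
# Sketch of the whitening idea used above: if x ~ N(mean, cov) with
# cov = L @ L.T, then solving L @ w = (x - mean) gives standard-normal w, which
# can be checked with a one-dimensional KS test (a statistical check, so it
# passes with high probability rather than always).
import numpy as np
import scipy.stats
from jax import random

dim, n = 3, 10000
r = np.random.RandomState(0)
a = r.randn(dim, dim)
cov = a @ a.T + dim * np.eye(dim)
mean = r.randn(dim)

samples = np.asarray(
    random.multivariate_normal(random.PRNGKey(0), mean, cov, shape=(n,)),
    np.float64)
L = np.linalg.cholesky(cov)
whitened = np.linalg.solve(L, (samples - mean).T).T
print(scipy.stats.kstest(whitened.ravel(), scipy.stats.norm().cdf).pvalue > 0.01)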

    @parameterized.named_parameters(jtu.cases_from_list(
        {"testcase_name": "_dim={}_mean_batch_size={}_cov_batch_size={}_shape={}"\
         .format(dim, mean_batch_size, cov_batch_size, shape),
         "dim": dim,
         "mean_batch_size": mean_batch_size,
         "cov_batch_size": cov_batch_size,
         "shape": shape}
        for dim in [1, 2, 4]
        for mean_batch_size in [(), (3,), (2, 3)]
        for cov_batch_size in [(), (3,), (2, 3)]
        for shape in [(), (1,), (5,)]))
    def testMultivariateNormalShapes(self, dim, mean_batch_size,
                                     cov_batch_size, shape):
        r = np.random.RandomState(0)
        key = random.PRNGKey(0)
        eff_batch_size = mean_batch_size \
          if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size
        mean = r.randn(*(mean_batch_size + (dim, )))
        cov_factor = r.randn(*(cov_batch_size + (dim, dim)))
        cov = np.einsum('...ij,...kj->...ik', cov_factor, cov_factor)
        cov += 1e-3 * np.eye(dim)
        shape = shape + eff_batch_size
        samples = random.multivariate_normal(key, mean, cov, shape=shape)
        assert samples.shape == shape + (dim, )

    def testMultivariateNormalCovariance(self):
        # test code based on https://github.com/google/jax/issues/1869
        N = 100000
        cov = jnp.array([[0.19, 0.00, -0.13, 0.00], [0.00, 0.29, 0.00, -0.23],
                         [-0.13, 0.00, 0.39, 0.00], [0.00, -0.23, 0.00, 0.49]])
        mean = jnp.zeros(4)

        out_np = np.random.RandomState(0).multivariate_normal(mean, cov, N)

        key = random.PRNGKey(0)
        out_jnp = random.multivariate_normal(key,
                                             mean=mean,
                                             cov=cov,
                                             shape=(N, ))

        var_np = out_np.var(axis=0)
        var_jnp = out_jnp.var(axis=0)
        self.assertAllClose(var_np,
                            var_jnp,
                            rtol=1e-2,
                            atol=1e-2,
                            check_dtypes=False)

        var_np = np.cov(out_np, rowvar=False)
        var_jnp = np.cov(out_jnp, rowvar=False)
        self.assertAllClose(var_np,
                            var_jnp,
                            rtol=1e-2,
                            atol=1e-2,
                            check_dtypes=False)

    def testIssue222(self):
        x = random.randint(random.PRNGKey(10003), (), 0, 0)
        assert x == 0

    def testFoldIn(self):
        key = random.PRNGKey(0)
        keys = [random.fold_in(key, i) for i in range(10)]
        assert np.unique(np.ravel(keys)).shape == (20, )
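
# fold_in derives a new key deterministically from an existing key and an
# integer, the usual way to get a distinct stream per step or per example
# without threading split() results around.  A minimal sketch:
from jax import random

key = random.PRNGKey(0)
step_keys = [random.fold_in(key, i) for i in range(3)]
print(step_keys)             # three distinct uint32 key pairs
print(random.split(key, 2))  # split() gives independent subkeys in one call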

    def testStaticShapeErrors(self):
        if config.read("jax_disable_jit"):
            raise SkipTest("test only relevant when jit enabled")

        @api.jit
        def feature_map(n, d, sigma=1.0, seed=123):
            key = random.PRNGKey(seed)
            W = random.normal(key, (d, n)) / sigma
            w = random.normal(key, (d, )) / sigma
            b = 2 * jnp.pi * random.uniform(key, (d, ))

            phi = lambda x, t: jnp.sqrt(2.0 / d) * jnp.cos(
                jnp.matmul(W, x) + w * t + b)
            return phi

        self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*',
                               lambda: feature_map(5, 3))

    def testIssue756(self):
        key = random.PRNGKey(0)
        w = random.normal(key, ())
        if FLAGS.jax_enable_x64:
            self.assertEqual(np.result_type(w), np.float64)
        else:
            self.assertEqual(np.result_type(w), np.float32)

    def testIssue1789(self):
        def f(x):
            return random.gamma(random.PRNGKey(0), x)

        grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))

    def testNoOpByOpUnderHash(self):
        def fail(*args, **kwargs):
            assert False

        apply_primitive, xla.apply_primitive = xla.apply_primitive, fail
        try:
            _ = random.threefry_2x32(np.zeros(2, np.uint32),
                                     np.arange(10, dtype=np.uint32))
        finally:
            xla.apply_primitive = apply_primitive

    def testPRNGValues(self):
        # Test to ensure consistent random values between JAX versions
        k = random.PRNGKey(0)

        if FLAGS.jax_enable_x64:
            self.assertAllClose(
                random.randint(k, (3, 3), 0, 8),
                np.array([[7, 2, 6], [2, 1, 0], [6, 7, 7]], dtype='int64'))
        else:
            self.assertAllClose(
                random.randint(k, (3, 3), 0, 8),
                np.array([[2, 1, 3], [6, 1, 5], [6, 3, 4]], dtype='int32'))

        self.assertAllClose(
            random.split(k, 4),
            np.array([[2285895361, 1501764800], [1518642379, 4090693311],
                      [433833334, 4221794875], [839183663, 3740430601]],
                     dtype='uint32'))

        self.assertAllClose(random.fold_in(k, 4),
                            np.array([2285895361, 433833334], dtype='uint32'))

    def testDtypeErrorMessage(self):
        with self.assertRaisesRegex(ValueError, r"dtype argument to.*"):
            random.normal(random.PRNGKey(0), (), dtype=jnp.int32)

    def testRandomBroadcast(self):
        """Issue 4033"""
        # test for broadcast issue in https://github.com/google/jax/issues/4033
        key = random.PRNGKey(0)
        shape = (10, 2)
        x = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2))
        assert x.shape == shape
        x = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2]))
        assert x.shape == shape

    def testMaxwellSample(self):
        num_samples = 10**5
        rng = random.PRNGKey(0)

        rand = lambda x: random.maxwell(x, (num_samples, ))
        crand = api.jit(rand)

        loc = scipy.stats.maxwell.mean()
        std = scipy.stats.maxwell.std()

        uncompiled_samples = rand(rng)
        compiled_samples = crand(rng)

        for samples in [uncompiled_samples, compiled_samples]:
            # Check first and second moments.
            self.assertEqual((num_samples, ), samples.shape)
            self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1)
            self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.maxwell().cdf)

    @parameterized.named_parameters(('test1', 4.0, 1.0), ('test2', 2.0, 3.0))
    def testWeibullSample(self, concentration, scale):
        num_samples = 10**5
        rng = random.PRNGKey(0)

        rand = lambda x: random.weibull_min(x, scale, concentration,
                                            (num_samples, ))
        crand = api.jit(rand)

        loc = scipy.stats.weibull_min.mean(c=concentration, scale=scale)
        std = scipy.stats.weibull_min.std(c=concentration, scale=scale)

        uncompiled_samples = rand(rng)
        compiled_samples = crand(rng)

        for samples in [uncompiled_samples, compiled_samples]:
            # Check first and second moments.
            self.assertEqual((num_samples, ), samples.shape)
            self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1)
            self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)
            self._CheckKolmogorovSmirnovCDF(
                samples,
                scipy.stats.weibull_min(c=concentration, scale=scale).cdf)

    @parameterized.named_parameters(('test1', 4.0, 1.0), ('test2', 2.0, 3.0))
    def testDoublesidedMaxwellSample(self, loc, scale):
        num_samples = 10**5
        rng = random.PRNGKey(0)

        rand = lambda key: random.double_sided_maxwell(rng, loc, scale,
                                                       (num_samples, ))
        crand = api.jit(rand)

        mean = loc
        std = np.sqrt(3.) * scale

        uncompiled_samples = rand(rng)
        compiled_samples = crand(rng)

        # Compute the double-sided Maxwell CDF from the one-sided Maxwell CDF.
        # With z = (x - loc) / scale and an independent Rademacher sign s:
        #   P(DSM <= x) = P(loc + scale * s * one_sided <= x)
        #               = P(s * one_sided <= z)
        #               = 1/2 P(one_sided <= z) + 1/2 P(-one_sided <= z)
        #               = 1/2 P(one_sided <= z) + 1/2 P(one_sided >= -z)
        #               = 1/2 CDF_maxwell(z) + 1/2 (1 - CDF_maxwell(-z))
        def double_sided_maxwell_cdf(x, loc, scale):
            pos = scipy.stats.maxwell().cdf((x - loc) / scale)
            neg = (1 - scipy.stats.maxwell().cdf((-x + loc) / scale))
            return (pos + neg) / 2

        for samples in [uncompiled_samples, compiled_samples]:
            # Check first and second moments.
            self.assertEqual((num_samples, ), samples.shape)
            self.assertAllClose(np.mean(samples), mean, atol=0., rtol=0.1)
            self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1)

            self._CheckKolmogorovSmirnovCDF(
                samples, lambda x: double_sided_maxwell_cdf(x, loc, scale))

    def testRademacher(self):
        rng = random.PRNGKey(0)
        num_samples = 10**5

        rand = lambda x: random.rademacher(x, (num_samples, ))
        crand = api.jit(rand)

        uncompiled_samples = rand(rng)
        compiled_samples = crand(rng)

        for samples in [uncompiled_samples, compiled_samples]:
            unique_values, counts = np.unique(samples, return_counts=True)
            assert len(unique_values) == 2
            assert len(counts) == 2

            self.assertAllClose(counts[0] / num_samples,
                                0.5,
                                rtol=1e-02,
                                atol=1e-02)
            self.assertAllClose(counts[1] / num_samples,
                                0.5,
                                rtol=1e-02,
                                atol=1e-02)

    def testChoiceShapeIsNotSequenceError(self):
        key = random.PRNGKey(0)
        with self.assertRaises(TypeError):
            random.choice(key, 5, 2, replace=False)
        with self.assertRaises(TypeError):
            random.choice(key, 5, 2, replace=True)

    def test_eval_shape_big_random_array(self):
        def f(x):
            return random.normal(random.PRNGKey(x), (int(1e12), ))

        with core.skipping_checks():  # check_jaxpr will materialize array
            api.eval_shape(f, 0)  # doesn't error
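
# jax.eval_shape traces the function with abstract values only, so the shape
# and dtype of the (here, terabyte-scale) output can be computed without ever
# allocating it.  A minimal sketch:
import jax
from jax import random

out = jax.eval_shape(
    lambda seed: random.normal(random.PRNGKey(seed), (int(1e12),)), 0)
print(out.shape, out.dtype)  # (1000000000000,) and the default float dtype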

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name": "seed={seed}_type={type}_jit={jit}".format(**dct),
             **dct}
            for dct in [
                {"seed": 0, "type": int, "jit": True, "key": [0, 0]},
                {"seed": 0, "type": int, "jit": False, "key": [0, 0]},
                {"seed": 1, "type": np.int32, "jit": True, "key": [0, 1]},
                {"seed": 1, "type": np.int32, "jit": False, "key": [0, 1]},
                {"seed": 2, "type": np.uint32, "jit": True, "key": [0, 2]},
                {"seed": 2, "type": np.uint32, "jit": False, "key": [0, 2]},
                {"seed": 3, "type": np.int64, "jit": True, "key": [0, 3]},
                {"seed": 3, "type": np.int64, "jit": False, "key": [0, 3]},
                {"seed": -1, "type": int, "jit": True,
                 "key": [4294967295, 4294967295]
                 if FLAGS.jax_enable_x64 else [0, 4294967295]},
                {"seed": -1, "type": int, "jit": False,
                 "key": [4294967295, 4294967295]
                 if FLAGS.jax_enable_x64 else [0, 4294967295]},
                {"seed": -2, "type": np.int32, "jit": True,
                 "key": [0, 4294967294]},
                {"seed": -2, "type": np.int32, "jit": False,
                 "key": [0, 4294967294]},
                {"seed": -3, "type": np.int64, "jit": True,
                 "key": [4294967295, 4294967293]
                 if FLAGS.jax_enable_x64 else [0, 4294967293]},
                {"seed": -3, "type": np.int64, "jit": False,
                 "key": [4294967295, 4294967293]
                 if FLAGS.jax_enable_x64 else [0, 4294967293]},
                {"seed": np.iinfo(np.int32).max + 100, "type": int,
                 "jit": True, "key": [0, 2147483747]},
                {"seed": np.iinfo(np.int32).max + 100, "type": int,
                 "jit": False, "key": [0, 2147483747]},
                {"seed": np.iinfo(np.int32).max + 101, "type": np.uint32,
                 "jit": True, "key": [0, 2147483748]},
                {"seed": np.iinfo(np.int32).max + 101, "type": np.uint32,
                 "jit": False, "key": [0, 2147483748]},
                {"seed": np.iinfo(np.int32).min - 100, "type": int,
                 "jit": True,
                 "key": [4294967295, 2147483548]
                 if FLAGS.jax_enable_x64 else [0, 2147483548]},
                {"seed": np.iinfo(np.int32).min - 100, "type": int,
                 "jit": False,
                 "key": [4294967295, 2147483548]
                 if FLAGS.jax_enable_x64 else [0, 2147483548]},
                {"seed": np.iinfo(np.int32).min - 101, "type": np.int64,
                 "jit": True,
                 "key": [4294967295, 2147483547]
                 if FLAGS.jax_enable_x64 else [0, 2147483547]},
                {"seed": np.iinfo(np.int32).min - 101, "type": np.int64,
                 "jit": False,
                 "key": [4294967295, 2147483547]
                 if FLAGS.jax_enable_x64 else [0, 2147483547]},
            ]))
    def test_prng_seeds_and_keys(self, seed, type, jit, key):
        seed = type(seed)
        if jit:
            actual = api.jit(random.PRNGKey)(seed)
        else:
            actual = random.PRNGKey(seed)
        expected = jnp.array(key, dtype=jnp.uint32)
        self.assertArraysEqual(actual, expected)
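
# PRNGKey packs the seed into a pair of uint32 words; as the cases above show,
# negative and out-of-range seeds wrap modulo 2**32, and whether the high word
# is populated depends on the jax_enable_x64 flag.  A minimal sketch:
from jax import random

print(random.PRNGKey(0))   # [0 0]
print(random.PRNGKey(42))  # [ 0 42]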

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": f"_seed={seed}_type={type}",
                "seed": seed,
                "type": type
            } for type in ["int", "np.array", "jnp.array"] for seed in
            [-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1,
             np.uint64((1 << 64) - 1)]))
    def test_prng_jit_invariance(self, seed, type):
        if type == "int" and seed == (1 << 64) - 1:
            self.skipTest("Expected failure: Python int too large.")
        type = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type]
        args_maker = lambda: [type(seed)]
        self._CompileAndCheck(random.PRNGKey, args_maker)

    def test_prng_errors(self):
        seed = np.iinfo(np.uint64).max
        with self.assertRaises(OverflowError):
            random.PRNGKey(seed)
        with self.assertRaises(OverflowError):
            api.jit(random.PRNGKey)(seed)
Example #8
class IndexingTest(jtu.JaxTestCase):
  """Tests for Numpy indexing translation rules."""

  @parameterized.named_parameters(jtu.cases_from_list({
      "testcase_name": "{}_inshape={}_indexer={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer),
      "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer
  } for name, index_specs in STATIC_INDEXING_TESTS
    for shape, indexer in index_specs
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_default]))
  def testStaticIndexing(self, shape, dtype, rng_factory, indexer):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    fun = lambda x: x[indexer]
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters({
      "testcase_name":
          "{}_inshape={}_indexer={}".format(name,
                                            jtu.format_shape_dtype_string(
                                                shape, dtype), indexer),
      "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer
  } for name, index_specs in STATIC_INDEXING_GRAD_TESTS
    for shape, indexer in index_specs
    for dtype in float_dtypes
    for rng_factory in [jtu.rand_default])
  def testStaticIndexingGrads(self, shape, dtype, rng_factory, indexer):
    rng = rng_factory()
    tol = 1e-2 if lnp.finfo(dtype).bits == 32 else None
    arg = rng(shape, dtype)
    fun = lambda x: x[indexer]**2
    check_grads(fun, (arg,), 2, tol, tol, tol)

  def _ReplaceSlicesWithTuples(self, idx):
    """Helper method to replace slices with tuples for dynamic indexing args."""
    if isinstance(idx, slice):
      triple = idx.start, idx.stop, idx.step
      isnone = [i for i, elt in enumerate(triple) if elt is None]
      zeros = itertools.repeat(0)
      nones = itertools.repeat(None)
      out = util.subvals(triple, zip(isnone, zeros))
      return out, lambda out: slice(*util.subvals(out, zip(isnone, nones)))
    elif isinstance(idx, (tuple, list)) and idx:
      t = type(idx)
      elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx))
      return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts)))
    else:
      return idx, lambda x: x

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
      for name, index_specs in [
          ("OneSliceIndex",
           [IndexSpec(shape=(5,), indexer=slice(1, 3)),
            IndexSpec(shape=(5, 4), indexer=slice(1, 3))]),
          ("TwoSliceIndices",
           [IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]),
          ("NonUnitStrides", [
              IndexSpec(shape=(3,), indexer=slice(None, None, -1)),
              IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
              IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
          ]),
          ("OnlyStartOrStopDynamic", [
              IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
              IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
          ]),
      ]
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng_factory, indexer):
    rng = rng_factory()
    unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)
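    # Slice start/stop/step must be static under jit; rebuilding the slice from
    # traced values inside the jitted function is expected to raise an IndexError.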

    @api.jit
    def fun(x, unpacked_indexer):
      indexer = pack_indexer(unpacked_indexer)
      return x[indexer]

    args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
    self.assertRaises(IndexError, lambda: fun(*args_maker()))

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
      for name, index_specs in [
          ("OneIntIndex",
           [IndexSpec(shape=(3,), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3,), indexer=-1),
            IndexSpec(shape=(3,), indexer=-2)]),
          ("TwoIntIndices",
           [IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))]),
          ("ThreeIntIndices",
           [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
      ]
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory, indexer):
    rng = rng_factory()
    unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

    def fun(x, unpacked_indexer):
      indexer = pack_indexer(unpacked_indexer)
      return x[indexer]

    args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
      for name, index_specs in [
          ("OneIntIndex",
           [IndexSpec(shape=(3,), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3,), indexer=-1),
            IndexSpec(shape=(3,), indexer=-2),
            ]),
          ("TwoIntIndices",
           [IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
            ]),
          ("ThreeIntIndices",
           [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
      ]
      for shape, indexer in index_specs
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default])
  def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng_factory, indexer):
    rng = rng_factory()
    tol = 1e-2 if lnp.finfo(dtype).bits == 32 else None
    unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

    @api.jit
    def fun(unpacked_indexer, x):
      indexer = pack_indexer(unpacked_indexer)
      return x[indexer]

    arr = rng(shape, dtype)
    check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
      for name, index_specs in ADVANCED_INDEXING_TESTS
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), indexer]
    fun = lambda x, idx: x[idx]
    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
      for name, index_specs in [
          ("One1DIntArrayIndex",
           [IndexSpec(shape=(3,), indexer=onp.array([0, 1])),
            IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])),
            IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])),
            IndexSpec(shape=(3,), indexer=onp.array([-1, 1])),
            IndexSpec(shape=(3,), indexer=onp.array([-2, -1])),
            ]),
          ("One2DIntArrayIndex",
           [IndexSpec(shape=(3,), indexer=onp.array([[0, 0]])),
            IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1],
                                                       [0, 1, -1]])),
            IndexSpec(shape=(3, 4, 5), indexer=onp.array([[0, 2, 0, 1],
                                                          [-1, -2, 1, 0]])),
            ]),
          ("Two1DIntArrayIndicesNoBroadcasting",
           [IndexSpec(shape=(3, 3), indexer=[onp.array([0, 1]),
                                             onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5), indexer=[onp.array([0, 2, 0, 1]),
                                                onp.array([-1, 0, -1, 2])]),
            ]),
          ("Two1DIntArrayIndicesWithBroadcasting",
           [IndexSpec(shape=(3, 3), indexer=[onp.array([[0, 1]]),
                                             onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5), indexer=[onp.array([[0, 2, 0, 1]]),
                                                onp.array([-1, 0, -1, 2])]),
            ]),
          ("ListOfPythonInts",
           [IndexSpec(shape=(3,), indexer=[0, 1, 0]),
            IndexSpec(shape=(3, 4, 5), indexer=[0, -1]),
            ]),
          ("ListOfListsOfPythonInts",
           [IndexSpec(shape=(3, 4, 5), indexer=[[0, 1]]),
            IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], [[2, 3, 0, 3]]]),
            ]),
          ("ListOfPythonIntsAndIntArrays",
           [IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]),
            IndexSpec(shape=(3, 4, 5), indexer=[0, 1,
                                                onp.array([[2, 3, 0, 3]])]),
            ]),
          ("ListOfListsOfPythonIntsAndIntArrays",
           [IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]),
            IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]],
                                                onp.array([[2, 3, 0, 3]])]),
            ]),
      ]
      for shape, indexer in index_specs
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default])
  def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng_factory, indexer):
    rng = rng_factory()
    tol = 1e-2 if lnp.finfo(dtype).bits == 32 else None
    arg = rng(shape, dtype)
    fun = lambda x: x[indexer]**2
    check_grads(fun, (arg,), 2, tol, tol, tol)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
      for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
    rng = rng_factory()
    indexer_with_dummies = [e if isinstance(e, onp.ndarray) else ()
                            for e in indexer]
    substitutes = [(i, e) for i, e in enumerate(indexer)
                   if not isinstance(e, onp.ndarray)]
    args_maker = lambda: [rng(shape, dtype), indexer_with_dummies]

    def fun(x, indexer_with_dummies):
      idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes))
      return x[idx]

    self._CompileAndCheck(fun, args_maker, check_dtypes=True)

  def testAdvancedIndexingManually(self):
    x = onp.random.RandomState(0).randn(3, 4, 5)
    index_array = onp.array([0, 2, -1, 0])

    op = lambda x, index_array: x[..., index_array, :]
    cop = api.jit(op)

    a1 = op(x, index_array)
    a2 = cop(x, index_array)

    self.assertAllClose(a1, a2, check_dtypes=True)

    op = lambda x, index_array: x[..., index_array, :, index_array, None]
    cop = api.jit(op)

    a1 = op(x, index_array)
    a2 = cop(x, index_array)

    self.assertAllClose(a1, a2, check_dtypes=True)

    op = lambda x, index_array: x[index_array, ..., index_array[:, None], None]
    cop = api.jit(op)

    a1 = op(x, index_array)
    a2 = cop(x, index_array)

    self.assertAllClose(a1, a2, check_dtypes=True)

  def testUnpacking(self):

    def foo(x):
      a, b, c = x
      return a + b + c

    cfoo = api.jit(foo)

    a1 = foo(onp.arange(3))
    a2 = cfoo(onp.arange(3))

    self.assertAllClose(a1, a2, check_dtypes=True)

  def testBooleanIndexingArray1D(self):
    idx = onp.array([True, True, False])
    x = api.device_put(onp.arange(3))
    ans = x[idx]
    expected = onp.arange(3)[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingList1D(self):
    idx = [True, True, False]
    x = api.device_put(onp.arange(3))
    ans = x[idx]
    expected = onp.arange(3)[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingArray2DBroadcast(self):
    idx = onp.array([True, True, False, True])
    x = onp.arange(8).reshape(4, 2)
    ans = api.device_put(x)[idx]
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingList2DBroadcast(self):
    idx = [True, True, False, True]
    x = onp.arange(8).reshape(4, 2)
    ans = api.device_put(x)[idx]
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingArray2D(self):
    idx = onp.array([[True, False],
                     [False, True],
                     [False, False],
                     [True, True]])
    x = onp.arange(8).reshape(4, 2)
    ans = api.device_put(x)[idx]
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingDynamicShapeError(self):
    x = onp.zeros(3)
    i = onp.array([True, True, False])
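    # The mask's True count determines the output shape, which jit cannot trace,
    # so boolean indexing under jit raises an IndexError.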
    self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, i))

  def testIssue187(self):
    x = lnp.ones((5, 5))
    x[[0, 2, 4], [0, 2, 4]]  # doesn't crash

    x = onp.arange(25).reshape((5, 5))
    ans = api.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x)
    expected = x[[0, 2, 4], [0, 2, 4]]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testJVPOfGradOfIndexing(self):
    # Should return a value, even though we didn't pass a symbolic zero as the
    # index tangent.
    x = lnp.ones((3, 4), lnp.float32)
    i = lnp.ones((3,), lnp.int32)
    f = lambda x, i: lnp.sum(x[i])
    primals, tangents = api.jvp(api.grad(f), (x, i), (x, onp.zeros_like(i)))
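    # i == [1, 1, 1], so grad(f) scatter-adds a one into row 1 three times, giving
    # per-row sums [0, 3, 0] broadcast across the 4 columns.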
    expected = onp.broadcast_to(
      onp.array([0, 3, 0], dtype=onp.float32)[:, None], (3, 4))
    self.assertAllClose(expected, primals, check_dtypes=True)
    self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True)

  def testTrivialGatherIsntGenerated(self):
    # https://github.com/google/jax/issues/1621
    jaxpr = api.make_jaxpr(lambda x: x[:, None])(onp.arange(4))
    self.assertEqual(len(jaxpr.jaxpr.eqns), 1)
    self.assertNotIn('gather', str(jaxpr))

  def testBooleanIndexingWithEmptyResult(self):
    # based on a TensorFlow Probability test that started failing after #1622
    x = lnp.array([-1])
    mask = lnp.array([False])
    ans = x[mask]  # doesn't crash

    expected = onp.array([-1])[onp.array([False])]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testFloatIndexingError(self):
    x = lnp.array([1, 2, 3])
    self.assertRaises(TypeError, lambda: x[3.5])
Example no. 9
class IndexedUpdateTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list({
      "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer,
          jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer,
       "update_shape": update_shape, "update_dtype": update_dtype,
       "op": op
  } for name, index_specs in STATIC_INDEXING_TESTS
    for shape, indexer in index_specs
    for op in UpdateOps
    for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
    for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
    for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes)
    for rng_factory in [jtu.rand_default]))
  def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                         rng_factory, indexer, op):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
    jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
    self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list({
      "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer,
          jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer,
       "update_shape": update_shape, "update_dtype": update_dtype,
       "op": op
  } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
    for shape, indexer in index_specs
    for op in UpdateOps
    for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
    for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
    for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes)
    for rng_factory in [jtu.rand_default]))
  def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, op):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
    jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
    self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list({
      "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer,
          jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer,
       "update_shape": update_shape, "update_dtype": update_dtype,
       "op": op
  } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
    for shape, indexer in index_specs
    for op in UpdateOps
    for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
    for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
    for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes)
    for rng_factory in [jtu.rand_default]))
  def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, op):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
    jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
    self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list({
      "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer,
          jtu.format_shape_dtype_string(update_shape, update_dtype), op.name),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer,
       "update_shape": update_shape, "update_dtype": update_dtype,
       "op": op
  } for name, index_specs in STATIC_INDEXING_TESTS
    for shape, indexer in index_specs
    for op in UpdateOps
    for dtype in float_dtypes
    for update_shape in _broadcastable_shapes(_update_shape(shape, indexer))
    for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)
    for rng_factory in [jtu.rand_default]))
  @jtu.skip_on_devices("tpu")  # TODO(mattjj,phawkins): tpu issues
  def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                              rng_factory, indexer, op):
    rng = rng_factory()
    jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add
    jax_fn = lambda x, y: jax_op(x, indexer, y)
    x = rng(shape, dtype)
    y = rng(update_shape, update_dtype)
    check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)

  def testSegmentSumBehavior(self):
    # testAdvancedIndexing compares against NumPy, and as a result doesn't check
    # repeated indices. This test is just a simple manual check, based on
    # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
    data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

    ans = ops.index_add(onp.zeros(onp.max(segment_ids) + 1), segment_ids, data)
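    # Repeated segment ids accumulate: [5+1+7, 2, 3+4, 1+3] == [13, 2, 7, 4].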
    expected = onp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testSegmentSum(self):
    data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

    # test with explicit num_segments
    ans = ops.segment_sum(data, segment_ids, num_segments=4)
    expected = onp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test without explicit num_segments
    ans = ops.segment_sum(data, segment_ids)
    expected = onp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)
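
A minimal standalone sketch of the functional-update idiom exercised by IndexedUpdateTest, assuming the same (now-deprecated) jax.ops API used in the tests above; newer JAX releases spell the same operations as x.at[idx].set(y) and x.at[idx].add(y).

import jax.numpy as jnp
from jax import ops

x = jnp.zeros(4)
y = ops.index_update(x, 1, 5.)  # returns a new array; x itself is unchanged
z = ops.index_add(y, jnp.array([0, 0, 3]), jnp.array([1., 2., 3.]))
# x == [0, 0, 0, 0]; y == [0, 5, 0, 0]; z == [3, 5, 0, 3] (repeated indices accumulate).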
Example no. 10
class SimulateTest(jtu.JaxTestCase):

    # pylint: disable=g-complex-comprehension
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in DTYPE))
    def test_nve_ensemble(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)
        pos_key, center_key, vel_key, mass_key = random.split(key, 4)
        R = random.normal(pos_key, (PARTICLE_COUNT, spatial_dimension),
                          dtype=dtype)
        R0 = random.normal(center_key, (PARTICLE_COUNT, spatial_dimension),
                           dtype=dtype)
        mass = random.uniform(mass_key, (PARTICLE_COUNT, ),
                              minval=0.1,
                              maxval=5.0,
                              dtype=dtype)
        _, shift = space.free()

        E = lambda R, **kwargs: np.sum((R - R0)**2)

        init_fn, apply_fn = simulate.nve(E, shift, 1e-3)
        apply_fn = jit(apply_fn)

        state = init_fn(vel_key, R, kT=0.5, mass=mass)

        E_T = lambda state: \
            E(state.position) + quantity.kinetic_energy(state.velocity, state.mass)
        E_initial = E_T(state)

        for _ in range(DYNAMICS_STEPS):
            state = apply_fn(state)
            E_total = E_T(state)
            assert np.abs(E_total - E_initial) < E_initial * 0.01
            assert state.position.dtype == dtype

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in DTYPE))
    def test_nve_jammed(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        state = test_util.load_test_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic(state.box[0, 0])

        E = energy.soft_sphere_pair(displacement_fn, state.species,
                                    state.sigma)

        init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
        apply_fn = jit(apply_fn)

        state = init_fn(key, state.real_position, kT=1e-3)

        E_T = lambda state: \
            E(state.position) + quantity.kinetic_energy(state.velocity, state.mass)
        E_initial = E_T(state) * np.ones((DYNAMICS_STEPS, ))

        def step_fn(i, state_and_energy):
            state, energy = state_and_energy
            state = apply_fn(state)
            energy = ops.index_update(energy, i, E_T(state))
            return state, energy

        Es = np.zeros((DYNAMICS_STEPS, ))
        state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': f'_dtype={dtype.__name__}_coordinates={coords}',
            'dtype': dtype,
            'coords': coords
        } for dtype in DTYPE for coords in COORDS))
    def test_nve_jammed_periodic_general(self, dtype, coords):
        key = random.PRNGKey(0)

        state = test_util.load_test_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic_general(
            state.box, coords == 'fractional')

        E = energy.soft_sphere_pair(displacement_fn, state.species,
                                    state.sigma)

        init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3)
        apply_fn = jit(apply_fn)

        state = init_fn(key, getattr(state, coords + '_position'), kT=1e-3)

        E_T = lambda state: \
            E(state.position) + quantity.kinetic_energy(state.velocity, state.mass)
        E_initial = E_T(state) * np.ones((DYNAMICS_STEPS, ))

        def step_fn(i, state_and_energy):
            state, energy = state_and_energy
            state = apply_fn(state)
            energy = ops.index_update(energy, i, E_T(state))
            return state, energy

        Es = np.zeros((DYNAMICS_STEPS, ))
        state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in DTYPE))
    def test_nve_neighbor_list(self, spatial_dimension, dtype):
        Nx = particles_per_side = 8
        spacing = f32(1.25)

        tol = 5e-12 if dtype == np.float64 else 5e-3

        L = Nx * spacing
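        # Arrange particles on a square (2D) or simple cubic (3D) lattice with the
        # given spacing.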
        if spatial_dimension == 2:
            R = np.stack([np.array(r) for r in onp.ndindex(Nx, Nx)]) * spacing
        elif spatial_dimension == 3:
            R = np.stack([np.array(r)
                          for r in onp.ndindex(Nx, Nx, Nx)]) * spacing

        R = np.array(R, dtype)

        displacement, shift = space.periodic(L)

        neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
            displacement, L)
        exact_energy_fn = energy.lennard_jones_pair(displacement)

        init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3)
        exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift,
                                                     1e-3)

        nbrs = neighbor_fn(R)
        state = init_fn(random.PRNGKey(0), R, kT=0.5, neighbor=nbrs)
        exact_state = exact_init_fn(random.PRNGKey(0), R, kT=0.5)

        def body_fn(i, state):
            state, nbrs, exact_state = state
            nbrs = neighbor_fn(state.position, nbrs)
            state = apply_fn(state, neighbor=nbrs)
            return state, nbrs, exact_apply_fn(exact_state)

        step = 0
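        # Run the dynamics in chunks of 100 steps; if the neighbor list overflowed
        # its buffer, rebuild it from the last accepted state and redo the chunk,
        # otherwise accept the new state.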
        for i in range(20):
            new_state, nbrs, new_exact_state = lax.fori_loop(
                0, 100, body_fn, (state, nbrs, exact_state))
            if nbrs.did_buffer_overflow:
                nbrs = neighbor_fn(state.position)
            else:
                state = new_state
                exact_state = new_exact_state
                step += 1
        assert state.position.dtype == dtype
        self.assertAllClose(state.position,
                            exact_state.position,
                            atol=tol,
                            rtol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            f'_dim={dim}_dtype={dtype.__name__}_sy_steps={sy_steps}',
            'spatial_dimension': dim,
            'dtype': dtype,
            'sy_steps': sy_steps,
        } for dim in SPATIAL_DIMENSION for dtype in DTYPE
                            for sy_steps in [1, 3, 5, 7]))
    def test_nvt_nose_hoover(self, spatial_dimension, dtype, sy_steps):
        key = random.PRNGKey(0)

        box_size = quantity.box_size_at_number_density(PARTICLE_COUNT,
                                                       f32(1.2),
                                                       spatial_dimension)
        displacement_fn, shift_fn = space.periodic(box_size)

        bonds_i = np.arange(PARTICLE_COUNT)
        bonds_j = np.roll(bonds_i, 1)
        bonds = np.stack([bonds_i, bonds_j])

        E = energy.simple_spring_bond(displacement_fn, bonds)

        invariant = partial(simulate.nvt_nose_hoover_invariant, E)
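        # The Nose-Hoover invariant (the extended-system "energy") should stay
        # constant along the trajectory, which is asserted at the end of the loop.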

        for _ in range(STOCHASTIC_SAMPLES):
            key, pos_key, vel_key, T_key, masses_key = random.split(key, 5)

            R = box_size * random.uniform(pos_key,
                                          (PARTICLE_COUNT, spatial_dimension),
                                          dtype=dtype)
            T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype)
            mass = 1 + random.uniform(masses_key, (PARTICLE_COUNT, ),
                                      dtype=dtype)
            init_fn, apply_fn = simulate.nvt_nose_hoover(E,
                                                         shift_fn,
                                                         1e-3,
                                                         T,
                                                         sy_steps=sy_steps)
            apply_fn = jit(apply_fn)

            state = init_fn(vel_key, R, mass=mass)

            initial = invariant(state, T)

            for _ in range(DYNAMICS_STEPS):
                state = apply_fn(state)

            T_final = quantity.temperature(state.velocity, state.mass)
            assert np.abs(T_final - T) / T < 0.1
            self.assertAllClose(invariant(state, T), initial, rtol=1e-4)
            self.assertEqual(state.position.dtype, dtype)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': f'dtype={dtype.__name__}_sy_steps={sy_steps}',
            'dtype': dtype,
            'sy_steps': sy_steps,
        } for dtype in DTYPE for sy_steps in [1, 3, 5, 7]))
    def test_nvt_nose_hoover_jammed(self, dtype, sy_steps):
        key = random.PRNGKey(0)

        state = test_util.load_test_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic(state.box[0, 0])

        E = energy.soft_sphere_pair(displacement_fn, state.species,
                                    state.sigma)
        invariant = partial(simulate.nvt_nose_hoover_invariant, E)

        kT = 1e-3
        init_fn, apply_fn = simulate.nvt_nose_hoover(E,
                                                     shift_fn,
                                                     1e-3,
                                                     kT=kT,
                                                     sy_steps=sy_steps)
        apply_fn = jit(apply_fn)

        state = init_fn(key, state.real_position)

        E_initial = invariant(state, kT) * np.ones((DYNAMICS_STEPS, ))

        def step_fn(i, state_and_energy):
            state, energy = state_and_energy
            state = apply_fn(state)
            energy = ops.index_update(energy, i, invariant(state, kT))
            return state, energy

        Es = np.zeros((DYNAMICS_STEPS, ))
        state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es))

        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': f'dtype={dtype.__name__}_sy_steps={sy_steps}',
            'dtype': dtype,
            'sy_steps': sy_steps,
        } for dtype in DTYPE for sy_steps in [1, 3, 5, 7]))
    def test_npt_nose_hoover_jammed(self, dtype, sy_steps):
        key = random.PRNGKey(0)

        state = test_util.load_test_state('simulation_test_state.npy', dtype)
        displacement_fn, shift_fn = space.periodic_general(state.box)

        E = energy.soft_sphere_pair(displacement_fn, state.species,
                                    state.sigma)
        invariant = partial(simulate.npt_nose_hoover_invariant, E)
        pressure_fn = partial(quantity.pressure, E)

        # Nose-Hoover chain options; the same dict is passed below for both the
        # barostat and the thermostat chains.
        nhc_kwargs = {'sy_steps': sy_steps}
        kT = 1e-3
        P = state.pressure
        init_fn, apply_fn = simulate.npt_nose_hoover(E, shift_fn, 1e-3, P, kT,
                                                     nhc_kwargs, nhc_kwargs)
        apply_fn = jit(apply_fn)

        state = init_fn(key, state.fractional_position, state.box)

        E_initial = invariant(state, P, kT) * np.ones((DYNAMICS_STEPS, ))
        P_target = P * np.ones((DYNAMICS_STEPS, ))

        def step_fn(i, state_energy_pressure):
            state, energy, pressure = state_energy_pressure
            state = apply_fn(state)
            energy = ops.index_update(energy, i, invariant(state, P, kT))
            box = simulate.npt_box(state)
            KE = quantity.kinetic_energy(state.velocity, state.mass)
            p = pressure_fn(state.position, box, KE)
            pressure = ops.index_update(pressure, i, p)
            return state, energy, pressure

        Es = np.zeros((DYNAMICS_STEPS, ))
        Ps = np.zeros((DYNAMICS_STEPS, ))
        state, Es, Ps = lax.fori_loop(0, DYNAMICS_STEPS, step_fn,
                                      (state, Es, Ps))

        tol = 1e-3 if dtype is f32 else 1e-7
        self.assertEqual(state.position.dtype, dtype)
        self.assertAllClose(Es, E_initial, rtol=tol, atol=tol)
        self.assertAllClose(Ps, P_target, rtol=0.05, atol=0.05)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in DTYPE))
    def test_nvt_langevin(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, R_key, R0_key, T_key, masses_key = random.split(key, 5)

            R = random.normal(R_key,
                              (LANGEVIN_PARTICLE_COUNT, spatial_dimension),
                              dtype=dtype)
            R0 = random.normal(R0_key,
                               (LANGEVIN_PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            _, shift = space.free()

            E = functools.partial(lambda R, R0, **kwargs: np.sum((R - R0)**2),
                                  R0=R0)

            T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype)
            mass = random.uniform(masses_key, (LANGEVIN_PARTICLE_COUNT, ),
                                  minval=0.1,
                                  maxval=10.0,
                                  dtype=dtype)
            init_fn, apply_fn = simulate.nvt_langevin(E,
                                                      shift,
                                                      f32(1e-2),
                                                      T,
                                                      gamma=f32(0.3))
            apply_fn = jit(apply_fn)

            state = init_fn(key, R, mass=mass, T_initial=dtype(1.0))

            T_list = []
            for step in range(LANGEVIN_DYNAMICS_STEPS):
                state = apply_fn(state)
                if step > 4000 and step % 100 == 0:
                    T_list += [
                        quantity.temperature(state.velocity, state.mass)
                    ]

            # TODO(schsam): It would be good to check Gaussinity of R and V in the
            # noninteracting case.
            T_emp = np.mean(np.array(T_list))
            assert np.abs(T_emp - T) < 0.1
            assert state.position.dtype == dtype

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in DTYPE))
    def test_brownian(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)
        key, T_split, mass_split = random.split(key, 3)

        _, shift = space.free()
        energy_fn = lambda R, **kwargs: f32(0)

        R = np.zeros((BROWNIAN_PARTICLE_COUNT, 2), dtype=dtype)
        mass = random.uniform(mass_split, (),
                              minval=0.1,
                              maxval=10.0,
                              dtype=dtype)
        T = random.uniform(T_split, (), minval=0.3, maxval=1.4, dtype=dtype)

        dt = f32(1e-2)
        gamma = f32(0.1)

        init_fn, apply_fn = simulate.brownian(energy_fn,
                                              shift,
                                              dt,
                                              T,
                                              gamma=gamma)
        apply_fn = jit(apply_fn)

        state = init_fn(key, R, mass)

        sim_t = f32(BROWNIAN_DYNAMICS_STEPS * dt)
        for _ in range(BROWNIAN_DYNAMICS_STEPS):
            state = apply_fn(state)

        msd = np.var(state.position)
        th_msd = dtype(2 * T / (mass * gamma) * sim_t)
        assert np.abs(msd - th_msd) / msd < 1e-2
        assert state.position.dtype == dtype
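
A compact standalone sketch of the NVE energy-conservation pattern repeated throughout SimulateTest, assuming the jax_md package layout these tests rely on (space, simulate, quantity); the toy energy function and particle positions are illustrative choices, not values taken from the tests.

import jax.numpy as np
from jax import random, jit
from jax_md import space, simulate, quantity

key = random.PRNGKey(0)
_, shift = space.free()
E = lambda R, **kwargs: np.sum(R ** 2)  # toy harmonic well centred at the origin

init_fn, apply_fn = simulate.nve(E, shift, 1e-3)
apply_fn = jit(apply_fn)
state = init_fn(key, np.ones((16, 2)), kT=0.5)

E_total = lambda s: E(s.position) + quantity.kinetic_energy(s.velocity, s.mass)
E0 = E_total(state)
for _ in range(100):
    state = apply_fn(state)
# A symplectic integrator should keep E_total(state) within a small fraction of E0.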
Example no. 11
class FftTest(jtu.JaxTestCase):

  def testNotImplemented(self):
    for name in jnp.fft._NOT_IMPLEMENTED:
      func = getattr(jnp.fft, name)
      with self.assertRaises(NotImplementedError):
        func()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inverse={}_real={}_shape={}_axes={}".format(
          inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes),
       "axes": axes, "shape": shape, "dtype": dtype,
       "rng_factory": rng_factory, "inverse": inverse, "real": real}
      for inverse in [False, True]
      for real in [False, True]
      for rng_factory in [jtu.rand_default]
      for dtype in (real_dtypes if real and not inverse else all_dtypes)
      for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)]
      for axes in _get_fftn_test_axes(shape)))
  def testFftn(self, inverse, real, shape, dtype, axes, rng_factory):
    rng = rng_factory(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    jnp_op = _get_fftn_func(jnp.fft, inverse, real)
    np_op = _get_fftn_func(np.fft, inverse, real)
    jnp_fn = lambda a: jnp_op(a, axes=axes)
    np_fn = lambda a: np_op(a, axes=axes) if axes is None or axes else a
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker)
    # Test gradient for differentiable types.
    if (FLAGS.jax_enable_x64 and
        dtype in (float_dtypes if real and not inverse else inexact_dtypes)):
      # TODO(skye): can we be more precise?
      tol = 0.15
      jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inverse={}_real={}".format(inverse, real),
       "inverse": inverse, "real": real}
      for inverse in [False, True]
      for real in [False, True]))
  def testFftnErrors(self, inverse, real):
    rng = jtu.rand_default(self.rng())
    name = 'fftn'
    if real:
      name = 'r' + name
    if inverse:
      name = 'i' + name
    func = _get_fftn_func(jnp.fft, inverse, real)
    self.assertRaisesRegex(
        ValueError,
        "jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. "
        "Got axes None with input rank 4.".format(name),
        lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None))
    self.assertRaisesRegex(
        ValueError,
        "jax.numpy.fft.{} does not support repeated axes. Got axes \\[1, 1\\].".format(name),
        lambda: func(rng([2, 3], dtype=np.float64), axes=[1, 1]))
    self.assertRaises(
        ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2]))
    self.assertRaises(
        ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3]))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inverse={}_real={}_hermitian={}_shape={}_axis={}".format(
          inverse, real, hermitian, jtu.format_shape_dtype_string(shape, dtype), axis),
       "axis": axis, "shape": shape, "dtype": dtype,
       "rng_factory": rng_factory, "inverse": inverse, "real": real,
       "hermitian": hermitian}
      for inverse in [False, True]
      for real in [False, True]
      for hermitian in [False, True]
      for rng_factory in [jtu.rand_default]
      for dtype in (real_dtypes if (real and not inverse) or (hermitian and inverse)
                                else all_dtypes)
      for shape in [(10,)]
      for axis in [-1, 0]))
  def testFft(self, inverse, real, hermitian, shape, dtype, axis, rng_factory):
    rng = rng_factory(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    name = 'fft'
    if real:
      name = 'r' + name
    elif hermitian:
      name = 'h' + name
    if inverse:
      name = 'i' + name
    jnp_op = getattr(jnp.fft, name)
    np_op = getattr(np.fft, name)
    jnp_fn = lambda a: jnp_op(a, axis=axis)
    np_fn = lambda a: np_op(a, axis=axis)
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inverse={}_real={}_hermitian={}".format(inverse, real, hermitian),
       "inverse": inverse, "real": real, "hermitian": hermitian}
      for inverse in [False, True]
      for real in [False, True]
      for hermitian in [False, True]))
  def testFftErrors(self, inverse, real, hermitian):
    rng = jtu.rand_default(self.rng())
    name = 'fft'
    if real:
      name = 'r' + name
    elif hermitian:
      name = 'h' + name
    if inverse:
      name = 'i' + name
    func = getattr(jnp.fft, name)

    self.assertRaisesRegex(
      ValueError,
      "jax.numpy.fft.{} does not support multiple axes. "
      "Please use jax.numpy.fft.{}n. "
      "Got axis = \\[1, 1\\].".format(name, name),
      lambda: func(rng([2, 3], dtype=np.float64), axis=[1, 1])
    )
    self.assertRaisesRegex(
      ValueError,
      "jax.numpy.fft.{} does not support multiple axes. "
      "Please use jax.numpy.fft.{}n. "
      "Got axis = \\(1, 1\\).".format(name, name),
      lambda: func(rng([2, 3], dtype=np.float64), axis=(1, 1))
    )
    self.assertRaises(
        ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[2]))
    self.assertRaises(
        ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[-3]))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inverse={}_real={}_shape={}_axes={}".format(
          inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes),
       "axes": axes, "shape": shape, "dtype": dtype,
       "rng_factory": rng_factory, "inverse": inverse, "real": real}
      for inverse in [False, True]
      for real in [False, True]
      for rng_factory in [jtu.rand_default]
      for dtype in (real_dtypes if real and not inverse else all_dtypes)
      for shape in [(16, 8, 4, 8), (16, 8, 4, 8, 4)]
      for axes in [(-2, -1), (0, 1), (1, 3), (-1, 2)]))
  def testFft2(self, inverse, real, shape, dtype, axes, rng_factory):
    rng = rng_factory(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    name = 'fft2'
    if real:
      name = 'r' + name
    if inverse:
      name = 'i' + name
    jnp_op = getattr(jnp.fft, name)
    np_op = getattr(np.fft, name)
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_op, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inverse={}_real={}".format(inverse, real),
       "inverse": inverse, "real": real}
      for inverse in [False, True]
      for real in [False, True]))
  def testFft2Errors(self, inverse, real):
    rng = jtu.rand_default(self.rng())
    name = 'fft2'
    if real:
      name = 'r' + name
    if inverse:
      name = 'i' + name
    func = getattr(jnp.fft, name)

    self.assertRaisesRegex(
      ValueError,
      "jax.numpy.fft.{} only supports 2 axes. "
      "Got axes = \\[0\\].".format(name),
      lambda: func(rng([2, 3], dtype=np.float64), axes=[0])
    )
    self.assertRaisesRegex(
      ValueError,
      "jax.numpy.fft.{} only supports 2 axes. "
      "Got axes = \\(0, 1, 2\\).".format(name),
      lambda: func(rng([2, 3, 3], dtype=np.float64), axes=(0, 1, 2))
    )
    self.assertRaises(
      ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2, 3]))
    self.assertRaises(
      ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3, -4]))

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_size={}_d={}".format(
      jtu.format_shape_dtype_string([size], dtype), d),
      "dtype": dtype, "size": size, "rng_factory": rng_factory, "d": d}
    for rng_factory in [jtu.rand_default]
    for dtype in all_dtypes
    for size in [9, 10, 101, 102]
    for d in [0.1, 2.]))
  def testFftfreq(self, size, d, dtype, rng_factory):
    rng = rng_factory(self.rng())
    args_maker = lambda: (rng([size], dtype),)
    jnp_op = jnp.fft.fftfreq
    np_op = np.fft.fftfreq
    jnp_fn = lambda a: jnp_op(size, d=d)
    np_fn = lambda a: np_op(size, d=d)
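    # For reference, np.fft.fftfreq(4, d=0.5) == [0., 0.5, -1., -0.5]
    # (frequencies in cycles per unit of the sample spacing d).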
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker)
    # Test gradient for differentiable types.
    if dtype in inexact_dtypes:
      tol = 0.15  # TODO(skye): can we be more precise?
      jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_n={}".format(n),
     "n": n}
    for n in [[0,1,2]]))
  def testFftfreqErrors(self, n):
    name = 'fftfreq'
    func = jnp.fft.fftfreq
    self.assertRaisesRegex(
      ValueError,
      "The n argument of jax.numpy.fft.{} only takes an int. "
      "Got n = \\[0, 1, 2\\].".format(name),
      lambda: func(n=n)
    )
    self.assertRaisesRegex(
      ValueError,
      "The d argument of jax.numpy.fft.{} only takes a single value. "
      "Got d = \\[0, 1, 2\\].".format(name),
      lambda: func(n=10, d=n)
    )

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_size={}_d={}".format(
      jtu.format_shape_dtype_string([size], dtype), d),
      "dtype": dtype, "size": size, "rng_factory": rng_factory, "d": d}
    for rng_factory in [jtu.rand_default]
    for dtype in all_dtypes
    for size in [9, 10, 101, 102]
    for d in [0.1, 2.]))
  def testRfftfreq(self, size, d, dtype, rng_factory):
    rng = rng_factory(self.rng())
    args_maker = lambda: (rng([size], dtype),)
    jnp_op = jnp.fft.rfftfreq
    np_op = np.fft.rfftfreq
    jnp_fn = lambda a: jnp_op(size, d=d)
    np_fn = lambda a: np_op(size, d=d)
    # Numpy promotes to complex128 aggressively.
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(jnp_fn, args_maker)
    # Test gradient for differentiable types.
    if dtype in inexact_dtypes:
      tol = 0.15  # TODO(skye): can we be more precise?
      jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_n={}".format(n),
     "n": n}
    for n in [[0, 1, 2]]))
  def testRfftfreqErrors(self, n):
    name = 'rfftfreq'
    func = jnp.fft.rfftfreq
    self.assertRaisesRegex(
      ValueError,
      "The n argument of jax.numpy.fft.{} only takes an int. "
      "Got n = \\[0, 1, 2\\].".format(name),
      lambda: func(n=n)
    )
    self.assertRaisesRegex(
      ValueError,
      "The d argument of jax.numpy.fft.{} only takes a single value. "
      "Got d = \\[0, 1, 2\\].".format(name),
      lambda: func(n=10, d=n)
    )

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "dtype={}_axes={}".format(
      jtu.format_shape_dtype_string(shape, dtype), axes),
      "dtype": dtype, "shape": shape, "rng_factory": rng_factory, "axes": axes}
    for rng_factory in [jtu.rand_default]
    for dtype in all_dtypes
    for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]]
    for axes in _get_fftn_test_axes(shape)))
  def testFftshift(self, shape, dtype, rng_factory, axes):
    rng = rng_factory(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes)
    np_fn = lambda arg: np.fft.fftshift(arg, axes=axes)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "dtype={}_axes={}".format(
      jtu.format_shape_dtype_string(shape, dtype), axes),
      "dtype": dtype, "shape": shape, "rng_factory": rng_factory, "axes": axes}
    for rng_factory in [jtu.rand_default]
    for dtype in all_dtypes
    for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]]
    for axes in _get_fftn_test_axes(shape)))
  def testIfftshift(self, shape, dtype, rng_factory, axes):
    rng = rng_factory(self.rng())
    args_maker = lambda: (rng(shape, dtype),)
    jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes)
    np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
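
A quick standalone check of the relationship the two tests above exercise: numpy's ifftshift undoes fftshift (the two shifts differ for odd-length axes, so the round trip is the meaningful identity).

import numpy as np

a = np.arange(9)
assert np.array_equal(np.fft.ifftshift(np.fft.fftshift(a)), a)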
Example no. 12
class PredictTest(jtu.JaxTestCase):
    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train={}_test={}_network={}_logits={}_{}'.format(
                train, test, network, out_logits, name),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'out_logits':
            out_logits,
            'fn_and_kernel':
            fn
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                            for out_logits in OUTPUT_LOGITS
                            for name, fn in KERNELS.items()))
    def testNTKMSEPrediction(self, train_shape, test_shape, network,
                             out_logits, fn_and_kernel):

        key = random.PRNGKey(0)

        key, split = random.split(key)
        data_train = random.normal(split, train_shape)

        key, split = random.split(key)
        data_labels = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        data_test = random.normal(split, test_shape)

        params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                       out_logits)

        # Regress to an MSE loss.
        loss = lambda params, x: \
            0.5 * np.mean((f(params, x) - data_labels) ** 2)
        grad_loss = jit(grad(loss))

        g_dd = ntk(data_train, None, 'ntk')
        g_td = ntk(data_test, data_train, 'ntk')

        predictor = predict.gradient_descent_mse(g_dd, data_labels, g_td)

        atol = ATOL
        rtol = RTOL
        step_size = 0.1

        if len(train_shape) > 2:
            # Hacky way to up the tolerance just for convolutions.
            atol = ATOL * 2
            rtol = RTOL * 2
            step_size = 0.1

        train_time = 100.0
        steps = int(train_time / step_size)

        opt_init, opt_update, get_params = optimizers.sgd(step_size)
        opt_state = opt_init(params)

        fx_initial_train = f(params, data_train)
        fx_initial_test = f(params, data_test)

        fx_pred_train, fx_pred_test = predictor(0.0, fx_initial_train,
                                                fx_initial_test)

        self.assertAllClose(fx_initial_train, fx_pred_train, True)
        self.assertAllClose(fx_initial_test, fx_pred_test, True)

        for i in range(steps):
            params = get_params(opt_state)
            opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

        params = get_params(opt_state)
        fx_train = f(params, data_train)
        fx_test = f(params, data_test)

        fx_pred_train, fx_pred_test = predictor(train_time, fx_initial_train,
                                                fx_initial_test)

        fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2))
        fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2))
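        # Dividing by the RMS displacement of the network outputs makes the error
        # (and the tolerances below) independent of the overall output scale.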

        fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
        fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

        self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train),
                            True, rtol, atol)
        self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), True,
                            rtol, atol)

    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train={}_test={}_network={}_logits={}_{}'.format(
                train, test, network, out_logits, name),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'out_logits':
            out_logits,
            'fn_and_kernel':
            fn
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                            for out_logits in OUTPUT_LOGITS
                            for name, fn in KERNELS.items()))
    def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                            fn_and_kernel):
        key = random.PRNGKey(0)

        key, split = random.split(key)
        data_train = random.normal(split, train_shape)

        key, split = random.split(key)
        data_labels = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        data_test = random.normal(split, test_shape)

        params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                       out_logits)

        # Regress to an MSE loss.
        loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
        grad_loss = jit(
            grad(lambda params, x: loss(f(params, x), data_labels)))

        g_dd = ntk(data_train, None, 'ntk')
        g_td = ntk(data_test, data_train, 'ntk')

        predictor = predict.gradient_descent(g_dd, data_labels, loss, g_td)

        atol = ATOL
        rtol = RTOL
        step_size = 0.5

        if len(train_shape) > 2:
            # Hacky way to up the tolerance just for convolutions.
            atol = ATOL * 2
            rtol = RTOL * 2
            step_size = 0.1

        train_time = 100.0
        steps = int(train_time / step_size)

        opt_init, opt_update, get_params = optimizers.sgd(step_size)
        opt_state = opt_init(params)

        fx_initial_train = f(params, data_train)
        fx_initial_test = f(params, data_test)

        fx_pred_train, fx_pred_test = predictor(0.0, fx_initial_train,
                                                fx_initial_test)

        self.assertAllClose(fx_initial_train, fx_pred_train, True)
        self.assertAllClose(fx_initial_test, fx_pred_test, True)

        for i in range(steps):
            params = get_params(opt_state)
            opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

        params = get_params(opt_state)
        fx_train = f(params, data_train)
        fx_test = f(params, data_test)

        fx_pred_train, fx_pred_test = predictor(train_time, fx_initial_train,
                                                fx_initial_test)

        fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2))
        fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2))

        fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
        fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

        self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train),
                            True, rtol, atol)
        self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), True,
                            rtol, atol)

    # TODO(schsam): Get this test passing with theoretical conv.
    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train={}_test={}_network={}_logits={}_{}'.format(
                train, test, network, out_logits, name),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'out_logits':
            out_logits,
            'fn_and_kernel':
            fn
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                            for out_logits in OUTPUT_LOGITS
                            for name, fn in KERNELS.items()
                            if len(train) == 2))
    def testNTKMomentumPrediction(self, train_shape, test_shape, network,
                                  out_logits, fn_and_kernel):
        key = random.PRNGKey(0)

        key, split = random.split(key)
        data_train = random.normal(split, train_shape)

        key, split = random.split(key)
        data_labels = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        data_test = random.normal(split, test_shape)

        params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                       out_logits)

        # Regression with an MSE loss.
        loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
        grad_loss = jit(
            grad(lambda params, x: loss(f(params, x), data_labels)))

        g_dd = ntk(data_train, None, 'ntk')
        g_td = ntk(data_test, data_train, 'ntk')

        atol = ATOL
        rtol = RTOL
        step_size = 0.5

        if len(train_shape) > 2:
            # Hacky way to up the tolerance just for convolutions.
            atol = ATOL * 2
            rtol = RTOL * 2
            step_size = 0.1

        train_time = 100.0
        steps = int(train_time / np.sqrt(step_size))
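        # Consistent with the predictor call below, each discrete momentum
        # update is taken to advance continuous time by sqrt(step_size), so
        # train_time corresponds to train_time / sqrt(step_size) updates.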

        init, predictor, get = predict.momentum(g_dd, data_labels, loss,
                                                step_size, g_td)

        opt_init, opt_update, get_params = momentum(step_size, 0.9)
        opt_state = opt_init(params)

        fx_initial_train = f(params, data_train)
        fx_initial_test = f(params, data_test)

        lin_state = init(fx_initial_train, fx_initial_test)
        fx_pred_train, fx_pred_test = get(lin_state)

        self.assertAllClose(fx_initial_train, fx_pred_train, True)
        self.assertAllClose(fx_initial_test, fx_pred_test, True)

        for i in range(steps):
            params = get_params(opt_state)
            opt_state = opt_update(i, grad_loss(params, data_train), opt_state)

        params = get_params(opt_state)
        fx_train = f(params, data_train)
        fx_test = f(params, data_test)

        lin_state = predictor(lin_state, train_time)
        fx_pred_train, fx_pred_test = get(lin_state)

        fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2))
        fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2))

        fx_error_train = (fx_train - fx_pred_train) / fx_disp_train
        fx_error_test = (fx_test - fx_pred_test) / fx_disp_test

        self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train),
                            True, rtol, atol)
        self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), True,
                            rtol, atol)

    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train={}_test={}_network={}_logits={}'.format(
                train, test, network, out_logits),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'out_logits':
            out_logits,
        } for train, test, network in zip(TRAIN_SHAPES[:-1], TEST_SHAPES[:-1],
                                          NETWORK[:-1])
                            for out_logits in OUTPUT_LOGITS))
    def testNTKMeanPrediction(self, train_shape, test_shape, network,
                              out_logits):

        key = random.PRNGKey(0)

        key, split = random.split(key)
        data_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        data_labels = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        data_test = np.cos(random.normal(split, test_shape))
        _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)
        mean_pred, var = predict.gp_inference(kernel_fn,
                                              data_train,
                                              data_labels,
                                              data_test,
                                              'ntk',
                                              diag_reg=0.,
                                              compute_var=True)
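        # gp_inference returns the exact posterior mean (and covariance, since
        # compute_var=True) of the infinite-width 'ntk' GP; mc_sampling below
        # cross-checks the mean against an empirical average of late-time
        # gradient-descent predictions over random initializations.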

        if xla_bridge.get_backend().platform == 'tpu':
            eigh = np.onp.linalg.eigh
        else:
            eigh = np.linalg.eigh

        self.assertEqual(var.shape[0], data_test.shape[0])
        min_eigh = np.min(eigh(var)[0])
        self.assertGreater(min_eigh + 1e-10, 0.)

        def mc_sampling(count=10):
            empirical_mean = 0.
            key = random.PRNGKey(100)
            init_fn, f, _ = _build_network(train_shape[1:], network,
                                           out_logits)
            _kernel_fn = empirical.empirical_kernel_fn(f)
            kernel_fn = jit(
                lambda x1, x2, params: _kernel_fn(x1, x2, params, 'ntk'))

            for _ in range(count):
                key, split = random.split(key)
                _, params = init_fn(split, train_shape)

                g_dd = kernel_fn(data_train, None, params)
                g_td = kernel_fn(data_test, data_train, params)
                predictor = predict.gradient_descent_mse(
                    g_dd, data_labels, g_td)

                fx_initial_train = f(params, data_train)
                fx_initial_test = f(params, data_test)

                _, fx_pred_test = predictor(1.0e8, fx_initial_train,
                                            fx_initial_test)
                empirical_mean += fx_pred_test
            return empirical_mean / count

        atol = ATOL
        rtol = RTOL
        mean_emp = mc_sampling(100)

        self.assertAllClose(mean_pred, mean_emp, True, rtol, atol)

    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train={}_test={}_network={}_logits={}'.format(
                train, test, network, out_logits),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'out_logits':
            out_logits,
        } for train, test, network in zip(TRAIN_SHAPES[:-1], TEST_SHAPES[:-1],
                                          NETWORK[:-1])
                            for out_logits in OUTPUT_LOGITS))
    def testGPInferenceGet(self, train_shape, test_shape, network, out_logits):

        key = random.PRNGKey(0)

        key, split = random.split(key)
        data_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        data_labels = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        data_test = np.cos(random.normal(split, test_shape))
        _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

        out = predict.gp_inference(kernel_fn,
                                   data_train,
                                   data_labels,
                                   data_test,
                                   'ntk',
                                   diag_reg=0.,
                                   compute_var=True)
        assert isinstance(out, predict.Gaussian)

        out = predict.gp_inference(kernel_fn,
                                   data_train,
                                   data_labels,
                                   data_test,
                                   'nngp',
                                   diag_reg=0.,
                                   compute_var=True)
        assert isinstance(out, predict.Gaussian)

        out = predict.gp_inference(kernel_fn,
                                   data_train,
                                   data_labels,
                                   data_test, ('ntk', ),
                                   diag_reg=0.,
                                   compute_var=True)
        assert len(out) == 1 and isinstance(out[0], predict.Gaussian)

        out = predict.gp_inference(kernel_fn,
                                   data_train,
                                   data_labels,
                                   data_test, ('ntk', 'nngp'),
                                   diag_reg=0.,
                                   compute_var=True)
        assert (len(out) == 2 and isinstance(out[0], predict.Gaussian)
                and isinstance(out[1], predict.Gaussian))

        out2 = predict.gp_inference(kernel_fn,
                                    data_train,
                                    data_labels,
                                    data_test, ('nngp', 'ntk'),
                                    diag_reg=0.,
                                    compute_var=True)
        self.assertAllClose(out[0], out2[1], True)
        self.assertAllClose(out[1], out2[0], True)

    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train={}_test={}_network={}_logits={}_get={}'.format(
                train, test, network, out_logits, get),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'out_logits':
            out_logits,
            'get':
            get,
        } for train, test, network in zip(TRAIN_SHAPES[:-1], TEST_SHAPES[:-1],
                                          NETWORK[:-1])
                            for out_logits in OUTPUT_LOGITS for get in GETS))
    def testInfiniteTimeAgreement(self, train_shape, test_shape, network,
                                  out_logits, get):

        key = random.PRNGKey(0)

        key, split = random.split(key)
        data_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        data_labels = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        data_test = np.cos(random.normal(split, test_shape))
        _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        inf_prediction = predict.gp_inference(kernel_fn,
                                              data_train,
                                              data_labels,
                                              data_test,
                                              get,
                                              diag_reg=reg,
                                              compute_var=True)
        prediction = predict.gradient_descent_mse_gp(kernel_fn,
                                                     data_train,
                                                     data_labels,
                                                     data_test,
                                                     get,
                                                     diag_reg=reg,
                                                     compute_var=True)
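        # As t -> infinity the gradient-descent-MSE GP prediction should
        # converge to the closed-form posterior above, so evaluating the
        # predictor at np.inf must match gp_inference.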

        finite_prediction = prediction(np.inf)

        self.assertAllClose(inf_prediction, finite_prediction, True)

    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train={}_test={}_network={}_logits={}'.format(
                train, test, network, out_logits),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'out_logits':
            out_logits,
        } for train, test, network in zip(TRAIN_SHAPES[:-1], TEST_SHAPES[:-1],
                                          NETWORK[:-1])
                            for out_logits in OUTPUT_LOGITS))
    def testZeroTimeAgreement(self, train_shape, test_shape, network,
                              out_logits):
        """Test that the NTK and NNGP agree at t=0."""

        key = random.PRNGKey(0)

        key, split = random.split(key)
        data_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        data_labels = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)

        key, split = random.split(key)
        data_test = np.cos(random.normal(split, test_shape))
        _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

        reg = 1e-7
        prediction = predict.gradient_descent_mse_gp(ker_fun,
                                                     data_train,
                                                     data_labels,
                                                     data_test,
                                                     diag_reg=reg,
                                                     get=('NTK', 'NNGP'),
                                                     compute_var=True)
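        # At t = 0 nothing has been learned: both the NTK and NNGP predictions
        # should have zero mean and covariance equal to the prior test-test
        # 'nngp' kernel, which is the reference constructed below.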

        zero_prediction = prediction(0.0)

        self.assertAllClose(zero_prediction.ntk, zero_prediction.nngp, True)
        reference = (np.zeros((test_shape[0], out_logits)),
                     ker_fun(data_test, data_test, get='nngp'))
        self.assertAllClose((reference, ) * 2, zero_prediction, True)
Ejemplo n.º 13
0
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                jtu.format_shape_dtype_string(shapes, dtype), axis, keepdims,
                return_sign, use_b),
            "shapes":
            shapes,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims,
            "return_sign":
            return_sign,
            "use_b":
            use_b
        } for shape_group in compatible_shapes for dtype in float_dtypes
                            for use_b in [False, True]
                            for shapes in itertools.product(
                                *((shape_group,
                                   shape_group) if use_b else (shape_group, )))
                            for axis in range(
                                -max(len(shape) for shape in shapes),
                                max(len(shape) for shape in shapes))
                            for keepdims in [False, True]
                            for return_sign in [False, True]))
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="invalid value encountered in .*")
    def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
        # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
        if jtu.device_under_test() != "cpu":
            rng = jtu.rand_some_inf_and_nan(self.rng())
        else:
            rng = jtu.rand_default(self.rng())
        # TODO(mattjj): test autodiff
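        # logsumexp(a, axis, b=b) computes log(sum(b * exp(a))) over `axis` in
        # a numerically stable way; return_sign=True additionally returns the
        # sign of the sum so that negative weights in b are handled.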
        if use_b:

            def scipy_fun(array_to_reduce, scale_array):
                return osp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            def lax_fun(array_to_reduce, scale_array):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
        else:

            def scipy_fun(array_to_reduce):
                return osp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            def lax_fun(array_to_reduce):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            args_maker = lambda: [rng(shapes[0], dtype)]
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtypes":
                    dtypes,
                    "test_autodiff":
                    rec.test_autodiff,
                    "nondiff_argnums":
                    rec.nondiff_argnums,
                    "scipy_op":
                    getattr(osp_special, rec.name),
                    "lax_op":
                    getattr(lsp_special, rec.name)
                } for shapes in itertools.combinations_with_replacement(
                    all_shapes, rec.nargs)
                for dtypes in (itertools.combinations_with_replacement(
                    rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else
                               itertools.product(*rec.dtypes)))
            for rec in JAX_SPECIAL_FUNCTION_RECORDS))
    def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes,
                            dtypes, test_autodiff, nondiff_argnums):
        rng = rng_factory(self.rng())
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)
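        # Where supported, also check first-order gradients of the JAX op with
        # the non-differentiable arguments held fixed at their sampled values.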

        if test_autodiff:

            def partial_lax_op(*vals):
                list_args = list(vals)
                for i in nondiff_argnums:
                    list_args.insert(i, args[i])
                return lax_op(*list_args)

            assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
            diff_args = [
                x for i, x in enumerate(args) if i not in nondiff_argnums
            ]
            jtu.check_grads(partial_lax_op,
                            diff_args,
                            order=1,
                            atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                            rtol=.1,
                            eps=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_d={}".format(
                jtu.format_shape_dtype_string(shape, dtype), d),
            "rng_factory":
            jtu.rand_positive,
            "shape":
            shape,
            "dtype":
            dtype,
            "d":
            d
        } for shape in all_shapes for dtype in float_dtypes
                            for d in [1, 2, 5]))
    def testMultigammaln(self, rng_factory, shape, dtype, d):
        def scipy_fun(a):
            return osp_special.multigammaln(a, d)

        def lax_fun(a):
            return lsp_special.multigammaln(a, d)

        rng = rng_factory(self.rng())
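        # Shift the positive samples by (d - 1) / 2 so that every argument
        # lies in the domain a > (d - 1) / 2 on which multigammaln is defined.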
        args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                tol={
                                    np.float32: 1e-3,
                                    np.float64: 1e-14
                                })
        self._CompileAndCheck(lax_fun, args_maker)

    def testIssue980(self):
        x = np.full((4, ), -1e20, dtype=np.float32)
        self.assertAllClose(np.zeros((4, ), dtype=np.float32),
                            lsp_special.expit(x))

    def testIssue3758(self):
        x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
        q = np.array([1., 40., 30.], dtype=np.float32)
        self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32),
                            lsp_special.zeta(x, q))

    def testXlogyShouldReturnZero(self):
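        # xlogy(x, y) is x * log(y) with the convention that the result is 0
        # when x == 0, even though log(0) diverges.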
        self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

    def testGradOfXlogyAtZero(self):
        partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
        self.assertAllClose(api.grad(partial_xlogy)(0.),
                            0.,
                            check_dtypes=False)

    def testXlog1pyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlog1py(0., -1.),
                            0.,
                            check_dtypes=False)

    def testGradOfXlog1pyAtZero(self):
        partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
        self.assertAllClose(api.grad(partial_xlog1py)(-1.),
                            0.,
                            check_dtypes=False)
Ejemplo n.º 14
0
class SMapTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_bond_no_type_static(self, spatial_dimension, dtype):
        harmonic = lambda dr, **kwargs: (dr - f32(1))**f32(2)
        disp, _ = space.free()
        metric = space.metric(disp)

        mapped = smap.bond(harmonic, metric, np.array([[0, 1], [0, 2]], i32))
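        # smap.bond maps the harmonic energy over the statically supplied
        # (0, 1) and (0, 2) bonds and sums the contributions; the manual
        # accumulation below serves as the reference value.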

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)

            accum = harmonic(metric(R[0], R[1])) + harmonic(metric(R[0], R[2]))

            self.assertAllClose(mapped(R), dtype(accum))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_bond_no_type_dynamic(self, spatial_dimension, dtype):
        harmonic = lambda dr, **kwargs: (dr - f32(1))**f32(2)
        disp, _ = space.free()
        metric = space.metric(disp)

        mapped = smap.bond(harmonic, metric)
        bonds = np.array([[0, 1], [0, 2]], i32)

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)

            accum = harmonic(metric(R[0], R[1])) + harmonic(metric(R[0], R[2]))

            self.assertAllClose(mapped(R, bonds), dtype(accum))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_bond_type_static(self, spatial_dimension, dtype):
        harmonic = lambda dr, sigma, **kwargs: (dr - sigma)**f32(2)
        disp, _ = space.free()
        metric = space.metric(disp)

        sigma = np.array([1.0, 2.0], f32)

        mapped = smap.bond(harmonic,
                           metric,
                           np.array([[0, 1], [0, 2]], i32),
                           np.array([0, 1], i32),
                           sigma=sigma)

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)

            accum = harmonic(metric(R[0], R[1]), 1) + harmonic(
                metric(R[0], R[2]), 2)

            self.assertAllClose(mapped(R), dtype(accum))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_bond_type_dynamic(self, spatial_dimension, dtype):
        harmonic = lambda dr, sigma, **kwargs: (dr - sigma)**f32(2)
        disp, _ = space.free()
        metric = space.metric(disp)

        sigma = np.array([1.0, 2.0], f32)

        mapped = smap.bond(harmonic, metric, sigma=sigma)
        bonds = np.array([[0, 1], [0, 2]], i32)
        bond_types = np.array([0, 1], i32)

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)

            accum = harmonic(metric(R[0], R[1]), 1) + harmonic(
                metric(R[0], R[2]), 2)

            self.assertAllClose(mapped(R, bonds, bond_types), dtype(accum))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_bond_params_dynamic(self, spatial_dimension, dtype):
        harmonic = lambda dr, sigma, **kwargs: (dr - sigma)**f32(2)
        disp, _ = space.free()
        metric = space.metric(disp)

        sigma = np.array([1.0, 2.0], f32)

        mapped = smap.bond(harmonic, metric, sigma=1.0)
        bonds = np.array([[0, 1], [0, 2]], i32)

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)

            accum = harmonic(metric(R[0], R[1]), 1) + harmonic(
                metric(R[0], R[2]), 2)

            self.assertAllClose(mapped(R, bonds, sigma=sigma), dtype(accum))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_bond_per_bond_static(self, spatial_dimension, dtype):
        harmonic = lambda dr, sigma, **kwargs: (dr - sigma)**f32(2)
        disp, _ = space.free()
        metric = space.metric(disp)

        sigma = np.array([1.0, 2.0], f32)

        mapped = smap.bond(harmonic,
                           metric,
                           np.array([[0, 1], [0, 2]], i32),
                           sigma=sigma)

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)

            accum = harmonic(metric(R[0], R[1]), 1) + harmonic(
                metric(R[0], R[2]), 2)

            self.assertAllClose(mapped(R), dtype(accum))

    def test_get_species_parameters(self):
        species = [(0, 0), (0, 1), (1, 0), (1, 1)]
        params = np.array([[2.0, 3.0], [3.0, 1.0]])
        global_params = 3.0
        self.assertAllClose(smap._get_species_parameters(params, species[0]),
                            2.0)
        self.assertAllClose(smap._get_species_parameters(params, species[1]),
                            3.0)
        self.assertAllClose(smap._get_species_parameters(params, species[2]),
                            3.0)
        self.assertAllClose(smap._get_species_parameters(params, species[3]),
                            1.0)
        for s in species:
            self.assertAllClose(smap._get_species_parameters(global_params, s),
                                3.0)

    def test_get_matrix_parameters(self):
        params = np.array([1.0, 2.0])
        params_mat_test = np.array([[1.0, 1.5], [1.5, 2.0]])
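        # _get_matrix_parameters combines per-particle parameters into a
        # per-pair matrix with the supplied combinator (here the arithmetic
        # mean); matrices and scalars are expected to pass through unchanged.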
        params_mat = smap._get_matrix_parameters(params, lambda x, y: 0.5 *
                                                 (x + y))
        self.assertAllClose(params_mat, params_mat_test)

        params_mat_direct = np.array([[1.0, 2.0], [3.0, 4.0]])
        self.assertAllClose(
            smap._get_matrix_parameters(params_mat_direct, None),
            params_mat_direct)

        params_scalar = 1.0
        self.assertAllClose(smap._get_matrix_parameters(params_scalar, None),
                            params_scalar)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_pair_no_species_scalar(self, spatial_dimension, dtype):
        square = lambda dr: dr**2
        displacement, _ = space.free()
        metric = lambda Ra, Rb, **kwargs: \
            np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

        mapped_square = smap.pair(square, metric)
        metric = space.map_product(metric)
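        # smap.pair sums the pairwise energy over all ordered pairs and halves
        # the result to avoid double counting, matching the 0.5 * sum
        # reference computed below.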

        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            self.assertAllClose(
                mapped_square(R),
                np.array(0.5 * np.sum(square(metric(R, R))), dtype=dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_pair_no_species_scalar_dynamic(self, spatial_dimension, dtype):
        square = lambda dr, epsilon: epsilon * dr**2
        displacement, _ = space.free()
        metric = lambda Ra, Rb, **kwargs: \
            np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

        mapped_square = smap.pair(square, metric, epsilon=1.0)
        metric = space.map_product(metric)

        key = random.PRNGKey(0)
        for _ in range(STOCHASTIC_SAMPLES):
            key, split1, split2 = random.split(key, 3)
            R = random.uniform(split1, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            epsilon = random.uniform(split2, (PARTICLE_COUNT, ), dtype=dtype)
            mat_epsilon = 0.5 * (epsilon[:, np.newaxis] +
                                 epsilon[np.newaxis, :])
            self.assertAllClose(
                mapped_square(R, epsilon=epsilon),
                np.array(0.5 * np.sum(square(metric(R, R), mat_epsilon)),
                         dtype=dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_pair_no_species_vector(self, spatial_dimension, dtype):
        square = lambda dr: np.sum(dr**2, axis=2)
        disp, _ = space.free()

        mapped_square = smap.pair(square, disp)

        disp = space.map_product(disp)
        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            mapped_ref = np.array(0.5 * np.sum(square(disp(R, R))),
                                  dtype=dtype)
            self.assertAllClose(mapped_square(R), mapped_ref)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_pair_no_species_vector_nonadditive(self, spatial_dimension,
                                                dtype):
        square = lambda dr, params: params * np.sum(dr**2, axis=2)
        disp, _ = space.free()

        mapped_square = smap.pair(square, disp, params=lambda x, y: x * y)

        disp = space.map_product(disp)
        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, R_key, params_key = random.split(key, 3)
            R = random.uniform(R_key, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            params = random.uniform(params_key, (PARTICLE_COUNT, ),
                                    dtype=dtype,
                                    minval=0.1,
                                    maxval=1.5)
            pp_params = params[None, :] * params[:, None]
            mapped_ref = np.array(0.5 * np.sum(square(disp(R, R), pp_params)),
                                  dtype=dtype)
            self.assertAllClose(mapped_square(R, params=params), mapped_ref)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype,
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_pair_static_species_scalar(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        square = lambda dr, param=1.0: param * dr**2
        params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

        key, split = random.split(key)
        species = random.randint(split, (PARTICLE_COUNT, ), 0, 2)
        displacement, _ = space.free()
        metric = lambda Ra, Rb, **kwargs: \
            np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

        mapped_square = smap.pair(square,
                                  metric,
                                  species=species,
                                  param=params)

        metric = space.map_product(metric)
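        # With a static species assignment, `param` is looked up per species
        # pair; the reference below accumulates the energy block by block over
        # the 2 x 2 species combinations.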

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            total = 0.0
            for i in range(2):
                for j in range(2):
                    param = params[i, j]
                    R_1 = R[species == i]
                    R_2 = R[species == j]
                    total = total + 0.5 * np.sum(
                        square(metric(R_1, R_2), param))
            self.assertAllClose(mapped_square(R), np.array(total, dtype=dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype,
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_pair_static_species_scalar_dynamic(self, spatial_dimension,
                                                dtype):
        key = random.PRNGKey(0)

        square = lambda dr, param=1.0: param * dr**2

        key, split = random.split(key)
        species = random.randint(split, (PARTICLE_COUNT, ), 0, 2)
        displacement, _ = space.free()
        metric = lambda Ra, Rb, **kwargs: \
            np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

        mapped_square = smap.pair(square, metric, species=species, param=1.0)

        metric = space.map_product(metric)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split1, split2 = random.split(key, 3)
            R = random.uniform(split1, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            params = random.uniform(split2, (2, 2), dtype=dtype)
            params = f32(0.5) * (params + params.T)
            total = 0.0
            for i in range(2):
                for j in range(2):
                    param = params[i, j]
                    R_1 = R[species == i]
                    R_2 = R[species == j]
                    total = total + 0.5 * np.sum(
                        square(metric(R_1, R_2), param))
            self.assertAllClose(mapped_square(R, param=params),
                                np.array(total, dtype=dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype,
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_pair_scalar_dummy_arg(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        square = lambda dr, param=f32(1.0), **unused_kwargs: param * dr**2

        key, split = random.split(key)
        R = random.normal(key, (PARTICLE_COUNT, spatial_dimension),
                          dtype=dtype)
        displacement, shift = space.free()

        mapped = smap.pair(square, space.metric(displacement))
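        # No numerical assertion here: the call below only checks that an
        # unused keyword argument (t) is accepted and ignored.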

        mapped(R, t=f32(0))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_pair_static_species_vector(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        square = lambda dr, param=1.0: param * np.sum(dr**2, axis=2)
        params = np.array([[1.0, 2.0], [2.0, 3.0]], dtype=f32)

        key, split = random.split(key)
        species = random.randint(split, (PARTICLE_COUNT, ), 0, 2)
        disp, _ = space.free()

        mapped_square = smap.pair(square, disp, species=species, param=params)

        disp = space.map_product(disp)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            total = 0.0
            for i in range(2):
                for j in range(2):
                    param = params[i, j]
                    R_1 = R[species == i]
                    R_2 = R[species == j]
                    total = total + 0.5 * np.sum(square(disp(R_1, R_2), param))
            self.assertAllClose(mapped_square(R), np.array(total, dtype=dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype,
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_pair_dynamic_species_scalar(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        square = lambda dr, param=1.0: param * dr**2
        params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

        key, split = random.split(key)
        species = random.randint(split, (PARTICLE_COUNT, ), 0, 2)
        displacement, _ = space.free()
        metric = space.metric(displacement)

        mapped_square = smap.pair(square, metric, species=2, param=params)

        metric = space.map_product(metric)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            total = 0.0
            for i in range(2):
                for j in range(2):
                    param = params[i, j]
                    R_1 = R[species == i]
                    R_2 = R[species == j]
                    total = total + 0.5 * np.sum(
                        square(metric(R_1, R_2), param))
            self.assertAllClose(mapped_square(R, species),
                                np.array(total, dtype=dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_pair_dynamic_species_vector(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        square = lambda dr, param=1.0: param * np.sum(dr**2, axis=-1)
        params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

        key, split = random.split(key)
        species = random.randint(split, (PARTICLE_COUNT, ), 0, 2)
        disp, _ = space.free()

        mapped_square = smap.pair(square, disp, species=2, param=params)

        disp = vmap(vmap(disp, (0, None), 0), (None, 0), 0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension),
                               dtype=dtype)
            total = 0.0
            for i in range(2):
                for j in range(2):
                    param = params[i, j]
                    R_1 = R[species == i]
                    R_2 = R[species == j]
                    total = total + 0.5 * np.sum(square(disp(R_1, R_2), param))
            self.assertAllClose(mapped_square(R, species),
                                np.array(total, dtype=dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}'
                              f'_format={format}'),
            'spatial_dimension':
            dim,
            'dtype':
            dtype,
            'format':
            format
        } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE
                            for format in NEIGHBOR_LIST_FORMAT))
    def test_pair_neighbor_list_scalar(self, spatial_dimension, dtype, format):
        key = random.PRNGKey(0)

        def truncated_square(dr, sigma):
            return np.where(dr < sigma, dr**2, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 4. * N**(1. / spatial_dimension)

        key, split = random.split(key)
        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        neighbor_square = smap.pair_neighbor_list(truncated_square,
                                                  d,
                                                  sigma=1.0)
        neighbor_square = jit(neighbor_square)
        mapped_square = jit(smap.pair(truncated_square, d, sigma=1.0))
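        # The potential is truncated at sigma and the neighbor list is built
        # with cutoff sigma, so summing over neighbor pairs only should
        # reproduce the dense smap.pair result.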

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (), minval=0.5, maxval=2.5)
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  sigma,
                                                  0.0,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, nbrs, sigma=sigma))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}'
                              f'_format={format}'),
            'spatial_dimension':
            dim,
            'dtype':
            dtype,
            'format':
            format
        } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE
                            for format in NEIGHBOR_LIST_FORMAT))
    def test_pair_neighbor_list_scalar_diverging_potential(
            self, spatial_dimension, dtype, format):
        key = random.PRNGKey(0)

        def potential(dr, sigma):
            return np.where(dr < sigma, dr**-6, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 4. * N**(1. / spatial_dimension)

        key, split = random.split(key)
        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        neighbor_square = jit(smap.pair_neighbor_list(potential, d, sigma=1.0))
        mapped_square = jit(smap.pair(potential, d, sigma=1.0))

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (), minval=0.5, maxval=2.5)
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  sigma,
                                                  0.0,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, nbrs, sigma=sigma))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}'
                              f'_format={format}'),
            'spatial_dimension':
            dim,
            'dtype':
            dtype,
            'format':
            format
        } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE
                            for format in NEIGHBOR_LIST_FORMAT))
    def test_pair_neighbor_list_force_scalar_diverging_potential(
            self, spatial_dimension, dtype, format):
        key = random.PRNGKey(0)

        def potential(dr, sigma):
            return np.where(dr < sigma, dr**-6, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 4. * N**(1. / spatial_dimension)

        key, split = random.split(key)
        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        neighbor_square = smap.pair_neighbor_list(potential, d, sigma=1.0)
        neighbor_square = jit(quantity.force(neighbor_square))
        mapped_square = jit(quantity.force(smap.pair(potential, d, sigma=1.0)))

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (), minval=0.5, maxval=4.5)
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  sigma,
                                                  0.0,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, nbrs, sigma=sigma))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}'
                              f'_format={format}'),
            'spatial_dimension':
            dim,
            'dtype':
            dtype,
            'format':
            format
        } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE
                            for format in NEIGHBOR_LIST_FORMAT))
    def test_pair_neighbor_list_scalar_params_no_species(
            self, spatial_dimension, dtype, format):
        key = random.PRNGKey(0)

        def truncated_square(dr, sigma):
            return np.where(dr < sigma, dr**2, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        key, split = random.split(key)
        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        neighbor_square = smap.pair_neighbor_list(truncated_square,
                                                  d,
                                                  sigma=1.0)
        neighbor_square = jit(neighbor_square)
        mapped_square = jit(smap.pair(truncated_square, d, sigma=1.0))

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (N, ), minval=0.5, maxval=1.5)
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  np.max(sigma),
                                                  0.,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, nbrs, sigma=sigma))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}'
                              f'_format={format}'),
            'spatial_dimension':
            dim,
            'dtype':
            dtype,
            'format':
            format
        } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE
                            for format in NEIGHBOR_LIST_FORMAT))
    def test_pair_neighbor_list_scalar_params_matrix(self, spatial_dimension,
                                                     dtype, format):
        key = random.PRNGKey(0)

        def truncated_square(dr, sigma):
            return np.where(dr < sigma, dr**2, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        key, split = random.split(key)
        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        neighbor_square = smap.pair_neighbor_list(truncated_square,
                                                  d,
                                                  sigma=1.0)
        neighbor_square = jit(neighbor_square)
        mapped_square = jit(smap.pair(truncated_square, d, sigma=1.0))

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (N, N), minval=0.5, maxval=1.5)
            sigma = 0.5 * (sigma + sigma.T)
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  np.max(sigma),
                                                  0.,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, nbrs, sigma=sigma))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}'
                              f'_format={format}'),
            'spatial_dimension':
            dim,
            'dtype':
            dtype,
            'format':
            format
        } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE
                            for format in NEIGHBOR_LIST_FORMAT))
    def test_pair_neighbor_list_scalar_params_species(self, spatial_dimension,
                                                      dtype, format):
        key = random.PRNGKey(0)

        def truncated_square(dr, sigma):
            return np.where(dr < sigma, dr**2, f32(0.))

        N = NEIGHBOR_LIST_PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)
        species = np.zeros((N, ), np.int32)
        species = np.where(np.arange(N) > N / 3, 1, species)
        species = np.where(np.arange(N) > 2 * N / 3, 2, species)

        key, split = random.split(key)
        disp, _ = space.periodic(box_size)
        d = space.metric(disp)

        neighbor_square = smap.pair_neighbor_list(truncated_square,
                                                  d,
                                                  species=species,
                                                  sigma=1.0)
        neighbor_square = jit(neighbor_square)
        mapped_square = smap.pair(truncated_square,
                                  d,
                                  species=species,
                                  sigma=1.0)
        mapped_square = jit(mapped_square)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (3, 3), minval=0.5, maxval=1.5)
            sigma = 0.5 * (sigma + sigma.T)
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  np.max(sigma),
                                                  0.,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, nbrs, sigma=sigma))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}'
                              f'_format={str(format).split(".")[-1]}'),
            'spatial_dimension':
            dim,
            'dtype':
            dtype,
            'format':
            format
        } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE
                            for format in NEIGHBOR_LIST_FORMAT))
    def test_pair_neighbor_list_vector(self, spatial_dimension, dtype, format):
        if format is partition.OrderedSparse:
            self.skipTest('Vector valued pair_neighbor_list not supported.')
        key = random.PRNGKey(0)

        def truncated_square(dR, sigma):
            dr = np.reshape(space.distance(dR), dR.shape[:-1] + (1, ))
            return np.where(dr < sigma, dR**2, f32(0.))

        N = PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        key, split = random.split(key)
        disp, _ = space.periodic(box_size)

        neighbor_square = jit(
            smap.pair_neighbor_list(truncated_square,
                                    disp,
                                    sigma=1.0,
                                    reduce_axis=(1, )))
        mapped_square = jit(
            smap.pair(truncated_square, disp, sigma=1.0, reduce_axis=(1, )))

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (), minval=0.5, maxval=1.5)
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  sigma,
                                                  0.,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma),
                                neighbor_square(R, nbrs, sigma=sigma))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}'
                              f'_format={format}'),
            'spatial_dimension':
            dim,
            'dtype':
            dtype,
            'format':
            format
        } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE
                            for format in NEIGHBOR_LIST_FORMAT))
    def test_pair_neighbor_list_vector_nonadditive(self, spatial_dimension,
                                                   dtype, format):

        if format is partition.OrderedSparse:
            self.skipTest('Vector valued pair_neighbor_list not supported.')

        key = random.PRNGKey(0)

        def truncated_square(dR, sigma):
            dr = space.distance(dR)
            return np.where(dr < sigma, dr**2, f32(0.))

        N = PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        key, split = random.split(key)
        disp, _ = space.periodic(box_size)

        neighbor_square = jit(
            smap.pair_neighbor_list(truncated_square,
                                    disp,
                                    sigma=lambda x, y: x * y,
                                    reduce_axis=(1, )))
        mapped_square = jit(
            smap.pair(truncated_square, disp, sigma=1.0, reduce_axis=(1, )))

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (N, ), minval=0.5, maxval=1.5)
            sigma_pair = sigma[:, None] * sigma[None, :]
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  np.max(sigma)**2,
                                                  0.,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma_pair),
                                neighbor_square(R, nbrs, sigma=sigma))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}'
                              f'_format={format}'),
            'spatial_dimension':
            dim,
            'dtype':
            dtype,
            'format':
            format
        } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE
                            for format in NEIGHBOR_LIST_FORMAT))
    def test_pair_neighbor_list_scalar_nonadditive(self, spatial_dimension,
                                                   dtype, format):
        key = random.PRNGKey(0)

        def truncated_square(dR, sigma):
            dr = space.distance(dR)
            return np.where(dr < sigma, dr**2, f32(0.))

        N = PARTICLE_COUNT
        box_size = 2. * N**(1. / spatial_dimension)

        key, split = random.split(key)
        disp, _ = space.periodic(box_size)

        neighbor_square = jit(
            smap.pair_neighbor_list(truncated_square,
                                    disp,
                                    sigma=lambda x, y: x * y))
        mapped_square = jit(smap.pair(truncated_square, disp, sigma=1.0))

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = box_size * random.uniform(split, (N, spatial_dimension),
                                          dtype=dtype)
            sigma = random.uniform(key, (N, ), minval=0.5, maxval=1.5)
            sigma_pair = sigma[:, None] * sigma[None, :]
            neighbor_fn = partition.neighbor_list(disp,
                                                  box_size,
                                                  np.max(sigma)**2,
                                                  0.,
                                                  format=format)
            nbrs = neighbor_fn.allocate(R)
            self.assertAllClose(mapped_square(R, sigma=sigma_pair),
                                neighbor_square(R, nbrs, sigma=sigma))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_triplet_no_species_scalar(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        angle_fn = lambda dR1, dR2: np.sum(np.square(dR1) + np.square(dR2))
        square = lambda dR: np.sum(np.square(dR))
        displacement, _ = space.free()
        metric = lambda Ra, Rb, **kwargs: \
            np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

        triplet_square = smap.triplet(angle_fn, displacement)
        metric = space.map_product(metric)

        count = PARTICLE_COUNT // 50

        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (count, spatial_dimension), dtype=dtype)

            self.assertAllClose(
                triplet_square(R) / count / 2.,
                np.array(0.5 * np.sum(metric(R, R)), dtype=dtype))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype,
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_triplet_static_species_scalar(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)
        angle_fn = lambda dR1, dR2, param=5.0: param * np.sum(np.square(dR1))
        square = lambda dR, param: param * np.sum(np.square(dR))
        params = f32(np.array([[[1., 1.], [2., 0.]], [[0., 2.], [1., 1.]]]))

        count = PARTICLE_COUNT // 50
        key, split = random.split(key)
        species = random.randint(split, (count, ), 0, 2)
        displacement, _ = space.free()
        metric = lambda Ra, Rb, **kwargs: \
          np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)
        triplet_square = smap.triplet(angle_fn,
                                      displacement,
                                      species=species,
                                      param=params,
                                      reduce_axis=None)

        metric = space.map_product(metric)
        for _ in range(STOCHASTIC_SAMPLES):
            key, split = random.split(key)
            R = random.uniform(split, (count, spatial_dimension), dtype=dtype)
            total = 0.
            for i in range(2):
                for j in range(2):
                    R_1 = R[species == i]
                    R_2 = R[species == j]
                    total += 0.5 * np.sum(metric(R_1, R_2))
            self.assertAllClose(
                triplet_square(R) / count, np.array(total, dtype=dtype))
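
# A minimal, self-contained sketch of the smap.pair pattern that the tests
# above verify against explicit double loops. It assumes only the jax_md API
# already used in this example (space.free, space.metric, smap.pair); the
# particle count and parameter values below are illustrative, not taken from
# the test suite.
from jax import jit, random
from jax_md import space, smap

def harmonic(dr, sigma=1.0):
    # Pairwise energy as a function of the scalar distance dr.
    return (dr - sigma) ** 2

displacement, _ = space.free()
metric = space.metric(displacement)
# smap.pair sums the pairwise function over particle pairs; kwargs given at
# construction act as default parameters.
energy_fn = jit(smap.pair(harmonic, metric, sigma=1.0))

R = random.uniform(random.PRNGKey(0), (16, 3))
total = energy_fn(R)                   # uses sigma fixed at construction time
total_wide = energy_fn(R, sigma=1.5)   # per-call kwargs override parameters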
Ejemplo n.º 15
0
class SMapTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_bond_no_type_static(self, spatial_dimension, dtype):
    harmonic = lambda dr, **kwargs: (dr - f32(1)) ** f32(2)
    disp, _ = space.free()
    metric = space.metric(disp)

    mapped = smap.bond(harmonic, metric, np.array([[0, 1], [0, 2]], i32))

    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      accum = harmonic(metric(R[0], R[1])) + harmonic(metric(R[0], R[2]))

      self.assertAllClose(mapped(R), dtype(accum), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_bond_no_type_dynamic(self, spatial_dimension, dtype):
    harmonic = lambda dr, **kwargs: (dr - f32(1)) ** f32(2)
    disp, _ = space.free()
    metric = space.metric(disp)

    mapped = smap.bond(harmonic, metric)
    bonds = np.array([[0, 1], [0, 2]], i32)

    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      accum = harmonic(metric(R[0], R[1])) + harmonic(metric(R[0], R[2]))

      self.assertAllClose(mapped(R, bonds), dtype(accum), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_bond_type_static(self, spatial_dimension, dtype):
    harmonic = lambda dr, sigma, **kwargs: (dr - sigma) ** f32(2)
    disp, _ = space.free()
    metric = space.metric(disp)

    sigma = np.array([1.0, 2.0], f32)

    mapped = smap.bond(
      harmonic, metric,
      np.array([[0, 1], [0, 2]], i32), np.array([0, 1], i32), sigma=sigma)

    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      accum = harmonic(metric(R[0], R[1]), 1) + harmonic(metric(R[0], R[2]), 2)

      self.assertAllClose(mapped(R), dtype(accum), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_bond_type_dynamic(self, spatial_dimension, dtype):
    harmonic = lambda dr, sigma, **kwargs: (dr - sigma) ** f32(2)
    disp, _ = space.free()
    metric = space.metric(disp)

    sigma = np.array([1.0, 2.0], f32)

    mapped = smap.bond(harmonic, metric, sigma=sigma)
    bonds = np.array([[0, 1], [0, 2]], i32)
    bond_types = np.array([0, 1], i32)

    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      accum = harmonic(metric(R[0], R[1]), 1) + harmonic(metric(R[0], R[2]), 2)

      self.assertAllClose(mapped(R, bonds, bond_types), dtype(accum), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_bond_params_dynamic(self, spatial_dimension, dtype):
    harmonic = lambda dr, sigma, **kwargs: (dr - sigma) ** f32(2)
    disp, _ = space.free()
    metric = space.metric(disp)

    sigma = np.array([1.0, 2.0], f32)

    mapped = smap.bond(harmonic, metric)
    bonds = np.array([[0, 1], [0, 2]], i32)

    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      accum = harmonic(metric(R[0], R[1]), 1) + harmonic(metric(R[0], R[2]), 2)

      self.assertAllClose(mapped(R, bonds, sigma=sigma), dtype(accum), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_bond_per_bond_static(self, spatial_dimension, dtype):
    harmonic = lambda dr, sigma, **kwargs: (dr - sigma) ** f32(2)
    disp, _ = space.free()
    metric = space.metric(disp)

    sigma = np.array([1.0, 2.0], f32)

    mapped = smap.bond(
      harmonic, metric, np.array([[0, 1], [0, 2]], i32), sigma=sigma)

    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)

      accum = harmonic(metric(R[0], R[1]), 1) + harmonic(metric(R[0], R[2]), 2)

      self.assertAllClose(mapped(R), dtype(accum), True)

  def test_get_species_parameters(self):
    species = [(0, 0), (0, 1), (1, 0), (1, 1)]
    params = np.array([[2.0, 3.0], [3.0, 1.0]])
    global_params = 3.0
    self.assertAllClose(
        smap._get_species_parameters(params, species[0]), 2.0, True)
    self.assertAllClose(
        smap._get_species_parameters(params, species[1]), 3.0, True)
    self.assertAllClose(
        smap._get_species_parameters(params, species[2]), 3.0, True)
    self.assertAllClose(
        smap._get_species_parameters(params, species[3]), 1.0, True)
    for s in species:
      self.assertAllClose(
          smap._get_species_parameters(global_params, s), 3.0, True)

  def test_get_matrix_parameters(self):
    params = np.array([1.0, 2.0])
    params_mat_test = np.array([[1.0, 1.5], [1.5, 2.0]])
    params_mat = smap._get_matrix_parameters(params)
    self.assertAllClose(params_mat, params_mat_test, True)

    params_mat_direct = np.array([[1.0, 2.0], [3.0, 4.0]])
    self.assertAllClose(
        smap._get_matrix_parameters(params_mat_direct), params_mat_direct, True)

    params_scalar = 1.0
    self.assertAllClose(
        smap._get_matrix_parameters(params_scalar), params_scalar, True)
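    # The expected matrices above are consistent with an arithmetic-mean
    # combining rule for per-type parameters,
    #   params_mat[i, j] = 0.5 * (params[i] + params[j]),
    # so [1.0, 2.0] expands to [[1.0, 1.5], [1.5, 2.0]], while an explicit
    # matrix or a scalar passes through unchanged.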

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_no_species_scalar(self, spatial_dimension, dtype):
    square = lambda dr: dr ** 2
    displacement, _ = space.free()
    metric = lambda Ra, Rb, **kwargs: \
        np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

    mapped_square = smap.pair(square, metric)
    metric = space.map_product(metric)

    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      self.assertAllClose(
        mapped_square(R),
        np.array(0.5 * np.sum(square(metric(R, R))), dtype=dtype), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_no_species_scalar_dynamic(self, spatial_dimension, dtype):
    square = lambda dr, epsilon: epsilon * dr ** 2
    displacement, _ = space.free()
    metric = lambda Ra, Rb, **kwargs: \
        np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

    mapped_square = smap.pair(square, metric)
    metric = space.map_product(metric)

    key = random.PRNGKey(0)
    for _ in range(STOCHASTIC_SAMPLES):
      key, split1, split2 = random.split(key, 3)
      R = random.uniform(
        split1, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      epsilon = random.uniform(split2, (PARTICLE_COUNT,), dtype=dtype)
      mat_epsilon = 0.5 * (epsilon[:, np.newaxis] + epsilon[np.newaxis, :])
      self.assertAllClose(
        mapped_square(R, epsilon=epsilon),
        np.array(0.5 * np.sum(
          square(metric(R, R), mat_epsilon)), dtype=dtype), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_no_species_vector(self, spatial_dimension, dtype):
    square = lambda dr: np.sum(dr ** 2, axis=2)
    disp, _ = space.free()

    mapped_square = smap.pair(square, disp)

    disp = space.map_product(disp)
    key = random.PRNGKey(0)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      mapped_ref = np.array(0.5 * np.sum(square(disp(R, R))), dtype=dtype)
      self.assertAllClose(mapped_square(R), mapped_ref, True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype,
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_static_species_scalar(self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    square = lambda dr, param=1.0: param * dr ** 2
    params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

    key, split = random.split(key)
    species = random.randint(split, (PARTICLE_COUNT,), 0, 2)
    displacement, _ = space.free()
    metric = lambda Ra, Rb, **kwargs: \
        np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

    mapped_square = smap.pair(
      square, metric, species=species, param=params)

    metric = space.map_product(metric)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      total = 0.0
      for i in range(2):
        for j in range(2):
          param = params[i, j]
          R_1 = R[species == i]
          R_2 = R[species == j]
          total = total + 0.5 * np.sum(square(metric(R_1, R_2), param))
      self.assertAllClose(mapped_square(R), np.array(total, dtype=dtype), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype,
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_static_species_scalar_dynamic(self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    square = lambda dr, param=1.0: param * dr ** 2

    key, split = random.split(key)
    species = random.randint(split, (PARTICLE_COUNT,), 0, 2)
    displacement, _ = space.free()
    metric = lambda Ra, Rb, **kwargs: \
        np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

    mapped_square = smap.pair(square, metric, species=species)

    metric = space.map_product(metric)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split1, split2 = random.split(key, 3)
      R = random.uniform(
        split1, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      params = random.uniform(
        split2, (2, 2), dtype=dtype)
      params = f32(0.5) * (params + params.T)
      total = 0.0
      for i in range(2):
        for j in range(2):
          param = params[i, j]
          R_1 = R[species == i]
          R_2 = R[species == j]
          total = total + 0.5 * np.sum(square(metric(R_1, R_2), param))
      self.assertAllClose(
        mapped_square(R, param=params), np.array(total, dtype=dtype), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype,
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_scalar_dummy_arg(
      self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    square = lambda dr, param=f32(1.0), **unused_kwargs: param * dr ** 2

    key, split = random.split(key)
    R = random.normal(key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
    displacement, shift = space.free()

    mapped = smap.pair(square, space.metric(displacement))

    mapped(R, t=f32(0))

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_static_species_vector(self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    square = lambda dr, param=1.0: param * np.sum(dr ** 2, axis=2)
    params = np.array([[1.0, 2.0], [2.0, 3.0]], dtype=f32)

    key, split = random.split(key)
    species = random.randint(split, (PARTICLE_COUNT,), 0, 2)
    disp, _ = space.free()

    mapped_square = smap.pair(square, disp, species=species, param=params)

    disp = space.map_product(disp)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      total = 0.0
      for i in range(2):
        for j in range(2):
          param = params[i, j]
          R_1 = R[species == i]
          R_2 = R[species == j]
          total = total + 0.5 * np.sum(square(disp(R_1, R_2), param))
      self.assertAllClose(mapped_square(R), np.array(total, dtype=dtype), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype,
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_dynamic_species_scalar(self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    square = lambda dr, param=1.0: param * dr ** 2
    params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

    key, split = random.split(key)
    species = random.randint(split, (PARTICLE_COUNT,), 0, 2)
    displacement, _ = space.free()
    metric = lambda Ra, Rb, **kwargs: \
        np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)

    mapped_square = smap.pair(
        square, metric, species=quantity.Dynamic, param=params)

    metric = space.map_product(metric)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      total = 0.0
      for i in range(2):
        for j in range(2):
          param = params[i, j]
          R_1 = R[species == i]
          R_2 = R[species == j]
          total = total + 0.5 * np.sum(square(metric(R_1, R_2), param))
      self.assertAllClose(
        mapped_square(R, species, 2), np.array(total, dtype=dtype), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_dynamic_species_vector(
      self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    square = lambda dr, param=1.0: param * np.sum(dr ** 2, axis=2)
    params = f32(np.array([[1.0, 2.0], [2.0, 3.0]]))

    key, split = random.split(key)
    species = random.randint(split, (PARTICLE_COUNT,), 0, 2)
    disp, _ = space.free()

    mapped_square = smap.pair(
        square, disp, species=quantity.Dynamic, param=params)

    disp = vmap(vmap(disp, (0, None), 0), (None, 0), 0)

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(
        split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype)
      total = 0.0
      for i in range(2):
        for j in range(2):
          param = params[i, j]
          R_1 = R[species == i]
          R_2 = R[species == j]
          total = total + 0.5 * np.sum(square(disp(R_1, R_2), param))
      self.assertAllClose(
        mapped_square(R, species, 2), np.array(total, dtype=dtype), True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_neighbor_list_scalar(
      self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    def truncated_square(dr, sigma):
      return np.where(dr < sigma, dr ** 2, f32(0.))

    tol = 2e-10 if dtype == np.float32 else None

    N = NEIGHBOR_LIST_PARTICLE_COUNT
    box_size = 4. * N ** (1. / spatial_dimension)

    key, split = random.split(key)
    disp, _ = space.periodic(box_size)
    d = space.metric(disp)

    neighbor_square = jit(smap.pair_neighbor_list(truncated_square, d))
    mapped_square = jit(smap.pair(truncated_square, d))

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
      sigma = random.uniform(key, (), minval=0.5, maxval=2.5)
      neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma, R))
      idx = neighbor_fn(R)
      self.assertAllClose(mapped_square(R, sigma=sigma),
                          neighbor_square(R, idx, sigma=sigma), True, tol, tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_neighbor_list_scalar_params_no_species(
      self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    def truncated_square(dr, sigma):
      return np.where(dr < sigma, dr ** 2, f32(0.))

    tol = 2e-10 if dtype == np.float32 else None

    N = NEIGHBOR_LIST_PARTICLE_COUNT
    box_size = 2. * N ** (1. / spatial_dimension)

    key, split = random.split(key)
    disp, _ = space.periodic(box_size)
    d = space.metric(disp)

    neighbor_square = jit(smap.pair_neighbor_list(truncated_square, d))
    mapped_square = jit(smap.pair(truncated_square, d))

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype)
      sigma = random.uniform(key, (N,), minval=0.5, maxval=1.5)
      neighbor_fn = jit(
        partition.neighbor_list(disp, box_size, np.max(sigma), R))
      idx = neighbor_fn(R)
      self.assertAllClose(mapped_square(R, sigma=sigma),
                          neighbor_square(R, idx, sigma=sigma), True, tol, tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_neighbor_list_scalar_params_matrix(
      self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    def truncated_square(dr, sigma):
      return np.where(dr < sigma, dr ** 2, f32(0.))

    tol = 2e-10 if dtype == np.float32 else None

    N = NEIGHBOR_LIST_PARTICLE_COUNT
    box_size = 2. * N ** (1. / spatial_dimension)

    key, split = random.split(key)
    disp, _ = space.periodic(box_size)
    d = space.metric(disp)

    neighbor_square = jit(smap.pair_neighbor_list(truncated_square, d))
    mapped_square = jit(smap.pair(truncated_square, d))

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype)
      sigma = random.uniform(key, (N, N), minval=0.5, maxval=1.5)
      sigma = 0.5 * (sigma + sigma.T)
      neighbor_fn = jit(
        partition.neighbor_list(disp, box_size, np.max(sigma), R))
      idx = neighbor_fn(R)
      self.assertAllClose(mapped_square(R, sigma=sigma),
                          neighbor_square(R, idx, sigma=sigma), True, tol, tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_neighbor_list_scalar_params_species(
      self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    def truncated_square(dr, sigma):
      return np.where(dr < sigma, dr ** 2, f32(0.))

    tol = 2e-10 if dtype == np.float32 else None

    N = NEIGHBOR_LIST_PARTICLE_COUNT
    box_size = 2. * N ** (1. / spatial_dimension)
    species = np.zeros((N,), np.int32)
    species = np.where(np.arange(N) > N / 3, 1, species)
    species = np.where(np.arange(N) > 2 * N / 3, 2, species)

    key, split = random.split(key)
    disp, _ = space.periodic(box_size)
    d = space.metric(disp)

    neighbor_square = jit(
      smap.pair_neighbor_list(truncated_square, d, species=species))
    mapped_square = jit(smap.pair(truncated_square, d, species=species))

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype)
      sigma = random.uniform(key, (3, 3), minval=0.5, maxval=1.5)
      sigma = 0.5 * (sigma + sigma.T)
      neighbor_fn = jit(
        partition.neighbor_list(disp, box_size, np.max(sigma), R))
      idx = neighbor_fn(R)
      self.assertAllClose(mapped_square(R, sigma=sigma),
                          neighbor_square(R, idx, sigma=sigma), True, tol, tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {
          'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__),
          'spatial_dimension': dim,
          'dtype': dtype
      } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
  def test_pair_neighbor_list_vector(self, spatial_dimension, dtype):
    key = random.PRNGKey(0)

    def truncated_square(dR, sigma):
      dr = np.reshape(space.distance(dR), dR.shape[:-1] + (1,))
      return np.where(dr < sigma, dR ** 2, f32(0.))

    tol = 5e-6 if dtype == np.float32 else 1e-14

    N = PARTICLE_COUNT
    box_size = 2. * N ** (1. / spatial_dimension)

    key, split = random.split(key)
    disp, _ = space.periodic(box_size)

    neighbor_square = jit(smap.pair_neighbor_list(
      truncated_square, disp, reduce_axis=(1,)))
    mapped_square = jit(smap.pair(truncated_square, disp, reduce_axis=(1,)))

    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = box_size * random.uniform(
        split, (N, spatial_dimension), dtype=dtype)
      sigma = random.uniform(key, (), minval=0.5, maxval=1.5)
      neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma, R))
      idx = neighbor_fn(R)
      self.assertAllClose(mapped_square(R, sigma=sigma),
                          neighbor_square(R, idx, sigma=sigma), True, tol, tol)
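
# A condensed sketch of the neighbor-list comparison pattern repeated
# throughout this example. It assumes the example's existing imports (np as
# jax.numpy, f32, jit, random, space, smap, partition) and the same older
# jax_md API used above, where partition.neighbor_list is given an example R
# and the returned function yields neighbor indices directly; values are
# illustrative only.
def truncated_square(dr, sigma):
    return np.where(dr < sigma, dr ** 2, f32(0.))

box_size = 10.0
disp, _ = space.periodic(box_size)
d = space.metric(disp)

dense_energy = jit(smap.pair(truncated_square, d))
nbr_energy = jit(smap.pair_neighbor_list(truncated_square, d))

R = box_size * random.uniform(random.PRNGKey(0), (64, 3))
sigma = 1.2
neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma, R))
idx = neighbor_fn(R)
# The neighbor-list evaluation should match the dense O(N^2) evaluation as
# long as the neighbor list's cutoff covers the interaction range.
assert np.allclose(dense_energy(R, sigma=sigma), nbr_energy(R, idx, sigma=sigma))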
Ejemplo n.º 16
0
class BatchTest(jtu.JaxTestCase):
    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'.format(
                train, test, network, name, batch_size),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'name':
            name,
            'kernel_fn':
            kernel_fn,
            'batch_size':
            batch_size
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                            for name, kernel_fn in KERNELS.items()
                            for batch_size in [2, 8]))
    def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
                   batch_size):
        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)
        kernel_fn = kernel_fn(key, train_shape[1:], network)
        kernel_batched = batch._serial(kernel_fn, batch_size=batch_size)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

    # We also exclude tests for dropout + parallel. It is not clear how best to
    # handle this case.
    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train_shape={}_test_shape={}_network={}_{}'.format(
                train, test, network, name),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'name':
            name,
            'kernel_fn':
            kernel_fn
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                            for name, kernel_fn in KERNELS.items()))
    def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
        utils.stub_out_pmap(batch, 2)
        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
        kernel_batched = batch._parallel(kernel_fn)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other, True)

    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'.format(
                train, test, network, name, batch_size),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'name':
            name,
            'kernel_fn':
            kernel_fn,
            'batch_size':
            batch_size
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                            for name, kernel_fn in KERNELS.items()
                            for batch_size in [2, 8]))
    def testComposition(self, train_shape, test_shape, network, name,
                        kernel_fn, batch_size):
        utils.stub_out_pmap(batch, 2)

        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        kernel_batched = batch._parallel(
            batch._serial(kernel_fn, batch_size=batch_size))
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batch._serial(batch._parallel(kernel_fn),
                                       batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'.format(
                train, test, network, name, batch_size),
            'train_shape':
            train,
            'test_shape':
            test,
            'network':
            network,
            'name':
            name,
            'kernel_fn':
            kernel_fn,
            'batch_size':
            batch_size
        } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                            for name, kernel_fn in KERNELS.items()
                            for batch_size in [2, 8]))
    def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size):
        utils.stub_out_pmap(batch, 2)

        key = random.PRNGKey(0)
        key, self_split, other_split = random.split(key, 3)
        data_self = random.normal(self_split, train_shape)
        data_other = random.normal(other_split, test_shape)

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        kernel_batched = batch.batch(kernel_fn, batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batch.batch(kernel_fn,
                                     batch_size=batch_size,
                                     store_on_device=False)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = random.PRNGKey(0)
        rng_self, rng_other = random.split(rng)
        x_self = random.normal(rng_self, (8, 10))
        x_other = random.normal(rng_other, (2, 10))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            # In the parallel setting, `x1_is_x2` is not computed correctly
            # when x1==x2.
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)

        # Check convolutional + pooling.
        x_self = random.normal(rng, (8, 10, 10, 3))
        x_other = random.normal(rng, (2, 10, 10, 3))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)
        ker_out = readout_ker_fn(
            block_ker_fn(x_self, x_other, marginalization='none'))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out._replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out, True)

    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_on_device={}_batch_size={}'.format(store_on_device, batch_size),
            'store_on_device':
            store_on_device,
            'batch_size':
            batch_size
        } for store_on_device in [True, False] for batch_size in [2, 8]))
    def testAnalyticKernelComposeSerial(self, store_on_device, batch_size):
        self._test_analytic_kernel_composition(
            partial(batch._serial,
                    batch_size=batch_size,
                    store_on_device=store_on_device))

    def testAnalyticKernelComposeParallel(self):
        utils.stub_out_pmap(batch, 2)
        self._test_analytic_kernel_composition(batch._parallel)

    @jtu.parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_on_device={}_batch_size={}'.format(store_on_device, batch_size),
            'store_on_device':
            store_on_device,
            'batch_size':
            batch_size
        } for store_on_device in [True, False] for batch_size in [2, 8]))
    def testAnalyticKernelComposeAutomatic(self, store_on_device, batch_size):
        utils.stub_out_pmap(batch, 2)
        self._test_analytic_kernel_composition(
            partial(batch.batch,
                    batch_size=batch_size,
                    store_on_device=store_on_device))

    def test_jit_or_pmap_broadcast(self):
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            res = np.abs(np.matmul(x1, x2))
            if do_square:
                res *= res
            if do_flip:
                res = -res

            res *= random.uniform(keys) * p
            return [res, params]

        params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
        x2 = np.arange(0, 10).reshape((10, ))
        keys = random.PRNGKey(1)

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=0)
        x1 = np.arange(0, 10).reshape((1, 10))
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=0):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=True,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=True)
                    self.assertAllClose(res_1, res_2, True)

        utils.stub_out_pmap(batch, 1)
        x1 = np.arange(0, 10).reshape((1, 10))
        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=1)
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=1):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=False,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None)
                    self.assertAllClose(res_1[0], res_2[0], True)
                    self.assertAllClose(
                        tree_map(partial(np.expand_dims, axis=0), res_1[1]),
                        res_2[1], True)

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=2)
        x1 = np.arange(0, 20).reshape((2, 10))
        utils.stub_out_pmap(batch, 2)

        def broadcast(arg):
            return np.broadcast_to(arg, (2, ) + arg.shape)

        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=2):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      p=0.2)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None,
                                              p=0.2)
                    self.assertAllClose(res_1[0][0], res_2[0][0], True)
                    self.assertAllClose(res_1[0][1], res_2[0][1], True)
                    self.assertAllClose(tree_map(broadcast, res_1[1]),
                                        res_2[1], True)
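
# A minimal sketch of the public batching entry point that the tests above
# exercise through batch._serial, batch._parallel, and batch.batch. It assumes
# the neural_tangents modules already imported in this example (stax, batch)
# and uses illustrative shapes and widths.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(256), stax.Relu(), stax.Dense(1))

# batch.batch tiles the inputs, evaluates kernel_fn per tile (optionally across
# devices), and reassembles the full kernel; the result should match the
# unbatched kernel_fn up to numerical noise.
batched_kernel_fn = batch.batch(kernel_fn, batch_size=2)

x1 = random.normal(random.PRNGKey(1), (8, 10))
x2 = random.normal(random.PRNGKey(2), (4, 10))
k = batched_kernel_fn(x1, x2)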
Ejemplo n.º 17
0
class ImageTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_target={}_method={}_antialias={}".format(
                    jtu.format_shape_dtype_string(image_shape, dtype),
                    jtu.format_shape_dtype_string(target_shape, dtype), method,
                    antialias),
                "dtype":
                dtype,
                "image_shape":
                image_shape,
                "target_shape":
                target_shape,
                "method":
                method,
                "antialias":
                antialias
            } for dtype in float_dtypes
            for target_shape, image_shape in itertools.
            combinations_with_replacement([[2, 3, 2, 4], [2, 6, 4, 4],
                                           [2, 33, 17, 4], [2, 50, 38, 4]], 2)
            for method in
            ["nearest", "bilinear", "lanczos3", "lanczos5", "bicubic"]
            for antialias in [False, True]))
    @unittest.skipIf(not tf, "Test requires TensorFlow")
    def testResizeAgainstTensorFlow(self, dtype, image_shape, target_shape,
                                    method, antialias):
        # TODO(phawkins): debug this. There is a small mismatch between TF and JAX
        # for some cases of non-antialiased bicubic downscaling; we would expect
        # exact equality.
        if method == "bicubic" and any(
                x < y for x, y in zip(target_shape, image_shape)):
            raise unittest.SkipTest(
                "non-antialiased bicubic downscaling mismatch")
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(image_shape, dtype), )

        def tf_fn(x):
            out = tf.image.resize(x.astype(np.float64),
                                  tf.constant(target_shape[1:-1]),
                                  method=method,
                                  antialias=antialias).numpy().astype(dtype)
            return out

        jax_fn = partial(image.resize,
                         shape=target_shape,
                         method=method,
                         antialias=antialias)
        self._CheckAgainstNumpy(tf_fn,
                                jax_fn,
                                args_maker,
                                check_dtypes=True,
                                tol={
                                    np.float16: 2e-2,
                                    jnp.bfloat16: 1e-1,
                                    np.float32: 1e-4,
                                    np.float64: 1e-4
                                })

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_target={}_method={}".format(
                jtu.format_shape_dtype_string(image_shape, dtype),
                jtu.format_shape_dtype_string(target_shape, dtype), method),
            "dtype":
            dtype,
            "image_shape":
            image_shape,
            "target_shape":
            target_shape,
            "method":
            method
        } for dtype in [np.float32] for target_shape, image_shape in
                            itertools.combinations_with_replacement(
                                [[3, 2], [6, 4], [33, 17], [50, 39]], 2)
                            for method in
                            ["nearest", "bilinear", "lanczos3", "bicubic"]))
    @unittest.skipIf(not PIL_Image, "Test requires PIL")
    def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method):
        rng = jtu.rand_uniform(self.rng())
        args_maker = lambda: (rng(image_shape, dtype), )

        def pil_fn(x):
            pil_methods = {
                "nearest": PIL_Image.NEAREST,
                "bilinear": PIL_Image.BILINEAR,
                "bicubic": PIL_Image.BICUBIC,
                "lanczos3": PIL_Image.LANCZOS,
            }
            img = PIL_Image.fromarray(x.astype(np.float32))
            out = np.asarray(img.resize(target_shape[::-1],
                                        pil_methods[method]),
                             dtype=dtype)
            return out

        jax_fn = partial(image.resize,
                         shape=target_shape,
                         method=method,
                         antialias=True)
        self._CheckAgainstNumpy(pil_fn, jax_fn, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_target={}_method={}".format(
                jtu.format_shape_dtype_string(image_shape, dtype),
                jtu.format_shape_dtype_string(target_shape, dtype), method),
            "dtype":
            dtype,
            "image_shape":
            image_shape,
            "target_shape":
            target_shape,
            "method":
            method
        } for dtype in inexact_dtypes for image_shape, target_shape in [
            ([3, 1, 2], [6, 1, 4]),
            ([1, 3, 2, 1], [1, 6, 4, 1]),
        ] for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"])
    )
    def testResizeUp(self, dtype, image_shape, target_shape, method):
        data = [64, 32, 32, 64, 50, 100]
        expected_data = {}
        expected_data["nearest"] = [
            64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 32.0, 32.0, 64.0,
            64.0, 32.0, 32.0, 64.0, 64.0, 50.0, 50.0, 100.0, 100.0, 50.0, 50.0,
            100.0, 100.0
        ]
        expected_data["linear"] = [
            64.0, 56.0, 40.0, 32.0, 56.0, 52.0, 44.0, 40.0, 40.0, 44.0, 52.0,
            56.0, 36.5, 45.625, 63.875, 73.0, 45.5, 56.875, 79.625, 91.0, 50.0,
            62.5, 87.5, 100.0
        ]
        expected_data["lanczos3"] = [
            75.8294, 59.6281, 38.4313, 22.23, 60.6851, 52.0037, 40.6454,
            31.964, 35.8344, 41.0779, 47.9383, 53.1818, 24.6968, 43.0769,
            67.1244, 85.5045, 35.7939, 56.4713, 83.5243, 104.2017, 44.8138,
            65.1949, 91.8603, 112.2413
        ]
        expected_data["lanczos5"] = [
            77.5699, 60.0223, 40.6694, 23.1219, 61.8253, 51.2369, 39.5593,
            28.9709, 35.7438, 40.8875, 46.5604, 51.7041, 21.5942, 43.5299,
            67.7223, 89.658, 32.1213, 56.784, 83.984, 108.6467, 44.5802,
            66.183, 90.0082, 111.6109
        ]
        expected_data["cubic"] = [
            70.1453, 59.0252, 36.9748, 25.8547, 59.3195, 53.3386, 41.4789,
            35.4981, 36.383, 41.285, 51.0051, 55.9071, 30.2232, 42.151,
            65.8032, 77.731, 41.6492, 55.823, 83.9288, 98.1026, 47.0363,
            62.2744, 92.4903, 107.7284
        ]
        x = np.array(data, dtype=dtype).reshape(image_shape)
        output = image.resize(x, target_shape, method)
        expected = np.array(expected_data[method],
                            dtype=dtype).reshape(target_shape)
        self.assertAllClose(output, expected, atol=1e-04)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_target={}_method={}_antialias={}".format(
                    jtu.format_shape_dtype_string(image_shape, dtype),
                    jtu.format_shape_dtype_string(target_shape, dtype), method,
                    antialias),
                "dtype":
                dtype,
                "image_shape":
                image_shape,
                "target_shape":
                target_shape,
                "method":
                method,
                "antialias":
                antialias
            } for dtype in [np.float32]
            for target_shape, image_shape in itertools.
            combinations_with_replacement([[2, 3, 2, 4], [2, 6, 4, 4],
                                           [2, 33, 17, 4], [2, 50, 38, 4]], 2)
            for method in ["bilinear", "lanczos3", "lanczos5", "bicubic"]
            for antialias in [False, True]))
    def testResizeGradients(self, dtype, image_shape, target_shape, method,
                            antialias):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(image_shape, dtype), )
        jax_fn = partial(image.resize,
                         shape=target_shape,
                         method=method,
                         antialias=antialias)
        jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_target={}_method={}".format(
                jtu.format_shape_dtype_string(image_shape, dtype),
                jtu.format_shape_dtype_string(target_shape, dtype), method),
            "dtype":
            dtype,
            "image_shape":
            image_shape,
            "target_shape":
            target_shape,
            "scale":
            scale,
            "translation":
            translation,
            "method":
            method
        } for dtype in inexact_dtypes
                            for image_shape, target_shape, scale, translation
                            in [([3, 1, 2], [6, 1, 4], [2.0, 1.0, 2.0],
                                 [1.0, 0.0, -1.0]),
                                ([1, 3, 2, 1], [1, 6, 4, 1],
                                 [1.0, 2.0, 2.0, 1.0], [0.0, 1.0, -1.0, 0.0])]
                            for method in
                            ["linear", "lanczos3", "lanczos5", "cubic"]))
    def testScaleAndTranslateUp(self, dtype, image_shape, target_shape, scale,
                                translation, method):
        data = [64, 32, 32, 64, 50, 100]
        # Note zeros occur in the output because the sampling location is outside
        # the boundaries of the input image.
        expected_data = {}
        expected_data["linear"] = [
            0.0, 0.0, 0.0, 0.0, 56.0, 40.0, 32.0, 0.0, 52.0, 44.0, 40.0, 0.0,
            44.0, 52.0, 56.0, 0.0, 45.625, 63.875, 73.0, 0.0, 56.875, 79.625,
            91.0, 0.0
        ]
        expected_data["lanczos3"] = [
            0.0, 0.0, 0.0, 0.0, 59.6281, 38.4313, 22.23, 0.0, 52.0037, 40.6454,
            31.964, 0.0, 41.0779, 47.9383, 53.1818, 0.0, 43.0769, 67.1244,
            85.5045, 0.0, 56.4713, 83.5243, 104.2017, 0.0
        ]
        expected_data["lanczos5"] = [
            0.0, 0.0, 0.0, 0.0, 60.0223, 40.6694, 23.1219, 0.0, 51.2369,
            39.5593, 28.9709, 0.0, 40.8875, 46.5604, 51.7041, 0.0, 43.5299,
            67.7223, 89.658, 0.0, 56.784, 83.984, 108.6467, 0.0
        ]
        expected_data["cubic"] = [
            0.0, 0.0, 0.0, 0.0, 59.0252, 36.9748, 25.8547, 0.0, 53.3386,
            41.4789, 35.4981, 0.0, 41.285, 51.0051, 55.9071, 0.0, 42.151,
            65.8032, 77.731, 0.0, 55.823, 83.9288, 98.1026, 0.0
        ]
        x = np.array(data, dtype=dtype).reshape(image_shape)
        # Should we test different float types here?
        scale_a = jnp.array(scale, dtype=jnp.float32)
        translation_a = jnp.array(translation, dtype=jnp.float32)
        output = image.scale_and_translate(x, target_shape,
                                           range(len(image_shape)), scale_a,
                                           translation_a, method)

        expected = np.array(expected_data[method],
                            dtype=dtype).reshape(target_shape)
        self.assertAllClose(output, expected, atol=2e-03)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_dtype={}_method={}_antialias={}".format(
                    jtu.dtype_str(dtype), method, antialias),
                "dtype":
                dtype,
                "method":
                method,
                "antialias":
                antialias
            } for dtype in inexact_dtypes
            for method in ["linear", "lanczos3", "lanczos5", "cubic"]
            for antialias in [True, False]))
    def testScaleAndTranslateDown(self, dtype, method, antialias):
        image_shape = [1, 6, 7, 1]
        target_shape = [1, 3, 3, 1]

        data = [
            51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25,
            92, 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90,
            43, 14, 89, 71, 32, 23, 23, 35, 93
        ]
        if antialias:
            expected_data = {}
            expected_data["linear"] = [
                43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
            ]
            expected_data["lanczos3"] = [
                43.2884, 57.9091, 54.6439, 48.5856, 58.2427, 53.7551, 0, 0, 0
            ]
            expected_data["lanczos5"] = [
                43.9209, 57.6360, 54.9575, 48.9272, 58.1865, 53.1948, 0, 0, 0
            ]
            expected_data["cubic"] = [
                42.9935, 59.1687, 54.2138, 48.2640, 58.2678, 54.4088, 0, 0, 0
            ]
        else:
            expected_data = {}
            expected_data["linear"] = [
                43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0
            ]
            expected_data["lanczos3"] = [
                44.1390, 87.8786, 63.3111, 25.1161, 20.8795, 53.6165, 0, 0, 0
            ]
            expected_data["lanczos5"] = [
                44.8835, 85.5896, 66.7231, 16.9983, 19.8891, 47.1446, 0, 0, 0
            ]
            expected_data["cubic"] = [
                43.6426, 88.8854, 60.6638, 31.4685, 22.1204, 58.3457, 0, 0, 0
            ]
        x = np.array(data, dtype=dtype).reshape(image_shape)

        expected = np.array(expected_data[method],
                            dtype=dtype).reshape(target_shape)
        scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
        translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

        output = image.scale_and_translate(x,
                                           target_shape, (0, 1, 2, 3),
                                           scale_a,
                                           translation_a,
                                           method,
                                           antialias=antialias)
        self.assertAllClose(output, expected, atol=2e-03)

        # Test that passing only the subset of dimensions with non-trivial
        # scale and translation produces the same result.
        output = image.scale_and_translate(x,
                                           target_shape, (1, 2),
                                           scale_a[1:3],
                                           translation_a[1:3],
                                           method,
                                           antialias=antialias)
        self.assertAllClose(output, expected, atol=2e-03)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "antialias={}".format(antialias),
            "antialias": antialias
        } for antialias in [True, False]))
    def testScaleAndTranslateJITs(self, antialias):
        image_shape = [1, 6, 7, 1]
        target_shape = [1, 3, 3, 1]

        data = [
            51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25,
            92, 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90,
            43, 14, 89, 71, 32, 23, 23, 35, 93
        ]
        if antialias:
            expected_data = [
                43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
            ]
        else:
            expected_data = [
                43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0
            ]
        x = jnp.array(data, dtype=jnp.float32).reshape(image_shape)

        expected = jnp.array(expected_data,
                             dtype=jnp.float32).reshape(target_shape)
        scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
        translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

        def jit_fn(in_array, s, t):
            return jax.image.scale_and_translate(
                in_array,
                target_shape, (0, 1, 2, 3),
                s,
                t,
                "linear",
                antialias,
                precision=jax.lax.Precision.HIGHEST)

        output = jax.jit(jit_fn)(x, scale_a, translation_a)
        self.assertAllClose(output, expected, atol=2e-03)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "antialias={}".format(antialias),
            "antialias": antialias
        } for antialias in [True, False]))
    def testScaleAndTranslateGradFinite(self, antialias):
        image_shape = [1, 6, 7, 1]
        target_shape = [1, 3, 3, 1]

        data = [
            51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25,
            92, 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90,
            43, 14, 89, 71, 32, 23, 23, 35, 93
        ]

        x = jnp.array(data, dtype=jnp.float32).reshape(image_shape)
        scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
        translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

        def scale_fn(s):
            return jnp.sum(
                jax.image.scale_and_translate(
                    x,
                    target_shape, (0, 1, 2, 3),
                    s,
                    translation_a,
                    "linear",
                    antialias,
                    precision=jax.lax.Precision.HIGHEST))

        scale_out = jax.grad(scale_fn)(scale_a)
        self.assertTrue(jnp.all(jnp.isfinite(scale_out)))

        def translate_fn(t):
            return jnp.sum(
                jax.image.scale_and_translate(
                    x,
                    target_shape, (0, 1, 2, 3),
                    scale_a,
                    t,
                    "linear",
                    antialias,
                    precision=jax.lax.Precision.HIGHEST))

        translate_out = jax.grad(translate_fn)(translation_a)
        self.assertTrue(jnp.all(jnp.isfinite(translate_out)))
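
# A minimal, standalone sketch (not part of the test class above) of the
# jax.image API exercised by these tests; the shapes and values below are
# illustrative assumptions, not taken from the tests.
import jax
import jax.numpy as jnp

img = jnp.arange(24., dtype=jnp.float32).reshape(1, 4, 6, 1)  # NHWC image
scale = jnp.array([1.0, 0.5, 0.5, 1.0], dtype=jnp.float32)    # shrink H and W
translation = jnp.zeros(4, dtype=jnp.float32)
out = jax.image.scale_and_translate(img, (1, 2, 3, 1), (0, 1, 2, 3),
                                    scale, translation, "linear", antialias=True)
# The op is differentiable; gradients w.r.t. the scale stay finite.
grad_scale = jax.grad(lambda s: jnp.sum(jax.image.scale_and_translate(
    img, (1, 2, 3, 1), (0, 1, 2, 3), s, translation, "linear",
    antialias=True)))(scale)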
Ejemplo n.º 18
0
class Jax2TfTest(tf_test_util.JaxToTfTestCase):
    def test_basics(self):
        f_jax = lambda x: jnp.sin(jnp.cos(x))
        _, res_tf = self.ConvertAndCompare(f_jax, jnp.float_(0.7))

    def test_input_output_naming(self):
        @jax2tf.convert
        def f(xs, y):
            return [jnp.add(x, y) for x in xs]

        @tf.function(autograph=False)
        def u(xs, y):
            xs = tf.nest.map_structure(tf.convert_to_tensor, xs)
            with tf.GradientTape() as tape:
                tf.nest.map_structure(tape.watch, xs)
                y = f(xs, y)
                tape.gradient(y, xs)
                return y

        cf = u.get_concrete_function([1., 2., 3.], 4.)
        g = cf.graph
        g.get_operation_by_name("jax2tf_arg_0")
        g.get_operation_by_name("jax2tf_arg_0_1")
        g.get_operation_by_name("jax2tf_arg_0_2")
        g.get_operation_by_name("jax2tf_arg_1")
        g.get_operation_by_name("jax2tf_out")
        g.get_operation_by_name("jax2tf_out_1")
        g.get_operation_by_name("jax2tf_out_2")
        with self.assertRaises(KeyError):
            g.get_operation_by_name("jax2tf_arg_2")
        with self.assertRaises(KeyError):
            g.get_operation_by_name("jax2tf_out_3")
        g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_0")
        g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_1")
        g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_1_1")
        g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_1_2")
        g.get_operation_by_name("jax2tf_vjp/jax2tf_out")
        g.get_operation_by_name("jax2tf_vjp/jax2tf_out_1")
        g.get_operation_by_name("jax2tf_vjp/jax2tf_out_2")
        g.get_operation_by_name("jax2tf_vjp/jax2tf_out_3")

    def test_pytrees(self):
        # Take and return pytrees
        def f_jax(
            x: Tuple[float, Dict[str,
                                 float]]) -> Tuple[float, Dict[str, float]]:
            x_a, x_dict = x
            return x_a * 2., {k: v * 3. for k, v in x_dict.items()}

        x = (jnp.float_(.7), {"a": jnp.float_(.8), "b": jnp.float_(.9)})
        self.ConvertAndCompare(f_jax, x)

    def test_variable_input(self):
        f_jax = lambda x: jnp.sin(jnp.cos(x))
        f_tf = jax2tf.convert(f_jax)
        v = tf.Variable(0.7, dtype=dtypes.canonicalize_dtype(jnp.float_))
        self.assertIsInstance(f_tf(v), tf.Tensor)
        self.assertAllClose(f_jax(0.7), f_tf(v))

    def test_jit(self):
        f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
        self.ConvertAndCompare(f_jax, jnp.float_(0.7))

    def test_nested_jit(self):
        f_jax = jax.jit(lambda x: jnp.sin(jax.jit(jnp.cos)(x)))
        f_tf = jax2tf.convert(f_jax)
        np.testing.assert_allclose(f_jax(0.7), f_tf(0.7))

    def test_converts_jax_arrays(self):
        f_tf = tf.function(lambda x: x)
        self.assertEqual(f_tf(jnp.zeros([])).numpy(), 0.)
        self.assertEqual(f_tf(jnp.ones([])).numpy(), 1.)
        f_tf = tf.function(lambda x: x + x)
        self.assertEqual(f_tf(jnp.ones([])).numpy(), 2.)

        # Test with ShardedDeviceArray.
        n = jax.local_device_count()
        mk_sharded = lambda f: jax.pmap(lambda x: x)(f([n]))
        f_tf = tf.function(lambda x: x)
        self.assertAllClose(f_tf(mk_sharded(jnp.zeros)).numpy(), np.zeros([n]))
        self.assertAllClose(f_tf(mk_sharded(jnp.ones)).numpy(), np.ones([n]))

    @jtu.skip_on_devices("gpu")
    def test_bfloat16_passed_by_tf(self):
        f_jax = lambda a, b: a + b
        f_tf = tf.function(jax2tf.convert(f_jax),
                           input_signature=[
                               tf.TensorSpec([512, 512], tf.bfloat16),
                               tf.TensorSpec([512, 512], tf.bfloat16)
                           ])
        self.assertIsNotNone(f_tf.get_concrete_function())

    @jtu.skip_on_devices("gpu")
    def test_bfloat16_returned_by_jax(self):
        f_jax = lambda a, b: (a + b).astype(jnp.bfloat16)
        f_tf = jax2tf.convert(f_jax)
        self.assertEqual(f_tf(1., 2.).dtype, tf.bfloat16)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_dtype={dtype.__name__}", dtype=dtype)
            for dtype in [np.int64, np.float64]))
    def test_converts_64bit(self, dtype=np.int64, with_function=False):
        if not config.jax_enable_x64:
            self.skipTest("requires x64 mode")
        big_const = np.full((5, ), 2**33, dtype=dtype)
        self.ConvertAndCompare(jnp.sin, big_const)
        f_conv = jax2tf.convert(jnp.sin)
        if with_function:
            f_conv = tf.function(f_conv)
        # We check also when we pass tf.Variable or tf.Tensor into the
        # converted function
        self.assertAllClose(jnp.sin(big_const), f_conv(tf.Variable(big_const)))
        self.assertAllClose(jnp.sin(big_const), f_conv(tf.constant(big_const)))

    def test_function(self):
        f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
        self.ConvertAndCompare(f_jax, jnp.float_(0.7))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"function={with_function}",
                 with_function=with_function)
            for with_function in [False, True]))
    def test_gradients_disabled(self, with_function=False):
        f_tf = jax2tf.convert(jnp.tan, with_gradient=False)
        if with_function:
            f_tf = tf.function(f_tf, autograph=False)
        x = tf.ones([])

        # With tf.function the error is raised when we evaluate f_tf(x), in
        # eager mode when we evaluate tape.gradient(y, x)
        with self.assertRaisesRegex(
                LookupError,
                "Gradient explicitly disabled.*The jax2tf-converted function does not support gradients"
        ):
            with tf.GradientTape() as tape:
                tape.watch(x)
                y = f_tf(x)
                _ = tape.gradient(y, x)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"function={with_function}",
                 with_function=with_function)
            for with_function in [False, True]))
    def test_gradients(self, with_function=True):
        def f(x, y):
            return x * x, x * y

        f_tf = jax2tf.convert(f, with_gradient=True)
        if with_function:
            f_tf = tf.function(f_tf, autograph=False)
        default_float_type = dtypes.canonicalize_dtype(jnp.float_)
        x = tf.Variable(4., dtype=default_float_type)
        y = tf.Variable(5., dtype=default_float_type)
        with tf.GradientTape(persistent=True) as tape:
            u, v = f_tf(x, y)

        self.assertAllClose(2. * 4., tape.gradient(u, x))
        self.assertAllClose(0., tape.gradient(u, y))
        self.assertAllClose(5., tape.gradient(v, x))
        self.assertAllClose(4., tape.gradient(v, y))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"function={with_function}",
                 with_function=with_function)
            for with_function in [False, True]))
    def test_gradients_pytree(self, with_function=True):
        def f(xy: Tuple[float, float]) -> Dict[str, float]:
            x, y = xy
            return dict(one=x * x, two=x * y)

        f_tf = jax2tf.convert(f, with_gradient=True)
        if with_function:
            f_tf = tf.function(f_tf, autograph=False)
        default_float_dtype = dtypes.canonicalize_dtype(jnp.float_)
        x = tf.Variable(4., dtype=default_float_dtype)
        y = tf.Variable(5., dtype=default_float_dtype)
        with tf.GradientTape(persistent=True) as tape:
            uv = f_tf((x, y))

        self.assertAllClose(2. * 4., tape.gradient(uv["one"], x))
        self.assertAllClose(0., tape.gradient(uv["one"], y))
        self.assertAllClose(5., tape.gradient(uv["two"], x))
        self.assertAllClose(4., tape.gradient(uv["two"], y))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"function={with_function}",
                 with_function=with_function)
            for with_function in [False, True]))
    def test_gradients_with_custom_jvp(self, with_function=True):
        """Check gradients, for a function with custom JVP."""
        @jax.custom_jvp
        def f(x):
            return x * x

        @f.defjvp
        def f_jvp(primals, tangents):
            # 3 * x * x_t
            x, = primals
            x_dot, = tangents
            primal_out = f(x)
            tangent_out = 3. * x * x_dot
            return primal_out, tangent_out

        self.assertAllClose(4. * 4., f(4.))
        self.assertAllClose(3. * 4., jax.grad(f)(4.))

        f_tf = jax2tf.convert(f, with_gradient=True)
        if with_function:
            f_tf = tf.function(f_tf, autograph=False)
        self.assertAllClose(4. * 4., f_tf(jnp.float_(4.)))
        x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
        with tf.GradientTape() as tape:
            tape.watch(x)
            y = f_tf(x)

        self.assertAllClose(4. * 4., y)
        self.assertAllClose(3. * 4., tape.gradient(y, x))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"function={with_function}",
                 with_function=with_function)
            for with_function in [False, True]))
    def test_gradients_with_custom_vjp(self, with_function=True):
        """Check gradients, for a function with custom VJP."""
        @jax.custom_vjp
        def f(x):
            return x * x

        # f_fwd: a -> (b, residual)
        def f_fwd(x):
            return f(x), 3. * x

        # f_bwd: (residual, CT b) -> [CT a]
        def f_bwd(residual, ct_b):
            return residual * ct_b,

        f.defvjp(f_fwd, f_bwd)

        self.assertAllClose(4. * 4., f(4.))
        self.assertAllClose(3. * 4., jax.grad(f)(4.))

        f_tf = jax2tf.convert(f, with_gradient=True)
        if with_function:
            f_tf = tf.function(f_tf, autograph=False)
        self.assertAllClose(4. * 4., f_tf(jnp.float_(4.)))
        x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_))
        with tf.GradientTape() as tape:
            tape.watch(x)
            y = f_tf(x)

        self.assertAllClose(4. * 4., y)
        self.assertAllClose(3. * 4., tape.gradient(y, x))

    def test_convert_argument_non_callable_error(self):
        with self.assertRaisesRegex(TypeError, "Expected a callable value"):
            jax2tf.convert(5.)

    def test_convert_argument_non_tensor_error(self):
        with self.assertRaisesRegex(TypeError,
                                    "Argument.*should be NumPy array"):
            jax2tf.convert(lambda x: x)(lambda y: y)

    def test_argument_eager_tensor(self):
        x = jax2tf.convert(jnp.sin)(1.)
        jax2tf.convert(jnp.cos)(x)  # No error

    def test_checkpoint_wrapper_types(self):
        m = tf.Module()
        m.a = [tf.Module(), tf.Module()]
        m.b = (tf.Module(), tf.Module())
        m.c = {'a': tf.Module(), 'b': tf.Module()}
        self.assertNotEqual(type(m.a), list)
        self.assertNotEqual(type(m.b), tuple)
        self.assertNotEqual(type(m.c), dict)
        self.assertLen(jax.tree_leaves(m.a), 2)
        self.assertLen(jax.tree_leaves(m.b), 2)
        self.assertLen(jax.tree_leaves(m.c), 2)

    def test_custom_jvp(self):
        """Conversion of function with custom JVP"""
        @jax.custom_jvp
        def f(x):
            return x * x

        @f.defjvp
        def f_jvp(primals, tangents):
            x, = primals
            x_dot, = tangents
            primal_out = f(x)
            tangent_out = 3. * x * x_dot
            return primal_out, tangent_out

        arg = jnp.float_(0.7)
        self.TransformConvertAndCompare(f, arg, None)
        self.TransformConvertAndCompare(f, arg, "jvp")
        self.TransformConvertAndCompare(f, arg, "vmap")
        self.TransformConvertAndCompare(f, arg, "jvp_vmap")
        self.TransformConvertAndCompare(f, arg, "grad")
        self.TransformConvertAndCompare(f, arg, "grad_vmap")

    def test_custom_vjp(self):
        """Conversion of function with custom VJP"""
        @jax.custom_vjp
        def f(x):
            return x * x

        # f_fwd: a -> (b, residual)
        def f_fwd(x):
            return f(x), 3. * x

        # f_bwd: (residual, CT b) -> [CT a]
        def f_bwd(residual, ct_b):
            return residual * ct_b,

        f.defvjp(f_fwd, f_bwd)
        arg = jnp.float_(0.7)
        self.TransformConvertAndCompare(f, arg, None)
        self.TransformConvertAndCompare(f, arg, "vmap")
        self.TransformConvertAndCompare(f, arg, "grad")
        self.TransformConvertAndCompare(f, arg, "grad_vmap")

    def test_remat1(self):
        @jax.remat
        def f(x1):
            x2 = jnp.sin(x1)
            x3 = jnp.sin(x2)
            x4 = jnp.sin(x3)
            return jnp.sum(x4)

        # The computation of grad(f) evaluates "sin" 5 times: 3 in the forward
        # pass and 2 more to rematerialize "x2" and "x3" in the backward pass.
        arg = np.arange(3.)
        self.TransformConvertAndCompare(f, arg, "grad")
        # TODO: check that the TF code also computes "sin" 5 times

    def test_remat_free_var(self):
        def f(x):
            y = 2 * x

            @jax.remat
            def g():
                return y

            return g()

        arg = jnp.float_(3.)
        self.TransformConvertAndCompare(f, arg, None)
        self.TransformConvertAndCompare(f, arg, "grad")

    def test_convert_nullary_func(self):
        # Even nullary functions are converted to TF (as opposed to constant-folded
        # in JAX prior to conversion).
        def f_jax():
            return jnp.sin(1.)

        f_tf = tf.function(jax2tf.convert(f_jax), autograph=False)
        f_tf_graph = f_tf.get_concrete_function().graph.as_graph_def()
        self.assertIn('op: "Sin"', str(f_tf_graph))

    def test_convert_of_nested_independent_jit(self):
        def func(x):
            def inner1(y):
                return x + y

            # The inner jit has no data dependency on x
            return jax.jit(inner1)(1.)

        jax2tf.convert(func)(2.)

    def test_convert_of_nested_dependent_jit(self):
        def func(x):
            def inner1(y):
                return x + y

            # The inner jit has a data dependency on x
            return jax.jit(inner1)(x)

        jax2tf.convert(func)(2.)  # No error

    def test_nested_convert_error(self):
        def outer(y):
            return jax2tf.convert(jnp.sin)(y)  # Inner convert takes tracer args

        with self.assertRaisesRegex(
                ValueError,
                "convert must be used outside all JAX transformations"):
            jax2tf.convert(outer)(np.ones((4, )))

    def test_nested_convert_error_non_tracer(self):
        """The inner convert takes non-tracer arguments"""
        def outer(y):
            sin_1 = jax2tf.convert(jnp.sin)(1.)  # Inner convert takes non-tracer arg
            return y + sin_1

        with self.assertRaisesRegex(
                ValueError,
                "convert must be used outside all JAX transformations"):
            jax2tf.convert(outer)(2.)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{transform}", transform=transform)
            for transform in ["jit", "jvp", "grad", "vmap"]))
    def test_convert_under_transform_error(self, transform="vmap"):
        def outer(y):
            return jax2tf.convert(jnp.sin)(y)  # Inner convert takes tracer args

        with self.assertRaisesRegex(
                ValueError,
                "convert must be used outside all JAX transformations"):
            self.TransformConvertAndCompare(outer, np.ones((4, )), transform)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{transform}", transform=transform)
            for transform in ["jit", "jvp", "grad", "vmap"]))
    def test_convert_under_transform_error_non_tracer(self, transform="vmap"):
        def outer(y):
            sin_1 = jax2tf.convert(jnp.sin)(1.)  # Inner convert takes non-tracer arg
            return y + sin_1

        with self.assertRaisesRegex(
                ValueError,
                "convert must be used outside all JAX transformations"):
            self.TransformConvertAndCompare(outer, np.ones((4, )), transform)

    def test_name_scope(self):
        log = []

        @jax.named_call
        def my_test_function(x):
            y = tf.Variable(1., name="foo")
            log.append(y.name)
            return x * x

        jax2tf.convert(my_test_function)(2)
        self.assertIn("my_test_function/foo", log[0])

    def test_bfloat16_constant(self):
        # Re: https://github.com/google/jax/issues/3942
        def jax_fn_scalar(x):
            x = x.astype(jnp.bfloat16)
            x *= 2.
            return x

        def jax_fn_array(x):
            x = x.astype(jnp.bfloat16)
            x *= np.array([1.5, 2.5, 3.5], jnp.bfloat16)
            return x

        tf_fn_scalar = jax2tf.convert(jax_fn_scalar)
        self.assertAllClose(tf_fn_scalar(1.375).numpy(), jnp.bfloat16(2.750))

        tf_fn_array = jax2tf.convert(jax_fn_array)
        self.assertAllClose(tf_fn_array(np.array([3, 4, 5])),
                            np.array([4.5, 10, 17.5], jnp.bfloat16))
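
# A minimal, standalone sketch (not part of the test class above) of the basic
# jax2tf.convert flow these tests exercise: convert a JAX function, wrap it in
# tf.function, and differentiate it with a TF GradientTape. The input value
# 0.7 is an illustrative assumption.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

f_jax = lambda x: jnp.sin(jnp.cos(x))
f_tf = tf.function(jax2tf.convert(f_jax, with_gradient=True), autograph=False)

x = tf.Variable(0.7, dtype=tf.float32)
with tf.GradientTape() as tape:
    y = f_tf(x)
dy_dx = tape.gradient(y, x)  # should match jax.grad(f_jax)(0.7) up to float32 tolerance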
Ejemplo n.º 19
0
class FftTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inverse={}_shape={}_axes={}".format(
                inverse, jtu.format_shape_dtype_string(shape, dtype), axes),
            "axes":
            axes,
            "shape":
            shape,
            "dtype":
            dtype,
            "rng_factory":
            rng_factory,
            "inverse":
            inverse
        } for inverse in [False, True] for rng_factory in [jtu.rand_default]
                            for dtype in all_dtypes
                            for shape in [(10, ), (10, 10), (2, 3,
                                                             4), (2, 3, 4, 5)]
                            for axes in _get_fftn_test_axes(shape)))
    def testFftn(self, inverse, shape, dtype, axes, rng_factory):
        rng = rng_factory()
        args_maker = lambda: (rng(shape, dtype), )
        np_op = np.fft.ifftn if inverse else np.fft.fftn
        onp_op = onp.fft.ifftn if inverse else onp.fft.fftn
        np_fn = lambda a: np_op(a, axes=axes)
        onp_fn = lambda a: onp_op(a, axes=axes)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(onp_fn,
                                np_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(np_fn, args_maker, check_dtypes=True)
        # Test gradient for differentiable types.
        if dtype in inexact_dtypes:
            # TODO(skye): can we be more precise?
            tol = 1e-1
            jtu.check_grads(np_fn, args_maker(), order=1, atol=tol, rtol=tol)
            jtu.check_grads(np_fn, args_maker(), order=2, atol=tol, rtol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_inverse={}".format(inverse),
            "inverse": inverse
        } for inverse in [False, True]))
    def testFftnErrors(self, inverse):
        rng = jtu.rand_default()
        name = 'ifftn' if inverse else 'fftn'
        func = np.fft.ifftn if inverse else np.fft.fftn
        self.assertRaisesRegex(
            ValueError, "jax.np.fft.{} only supports 1D, 2D, and 3D FFTs. "
            "Got axes None with input rank 4.".format(name),
            lambda: func(rng([2, 3, 4, 5], dtype=onp.float64), axes=None))
        self.assertRaisesRegex(
            ValueError,
            "jax.np.fft.{} does not support repeated axes. Got axes \\[1, 1\\]."
            .format(name),
            lambda: func(rng([2, 3], dtype=onp.float64), axes=[1, 1]))
        self.assertRaises(
            ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[2]))
        self.assertRaises(
            ValueError,
            lambda: func(rng([2, 3], dtype=onp.float64), axes=[-3]))
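
# A minimal, standalone sketch (not part of the test class above) of the
# comparison pattern used there: JAX's fftn against NumPy's fftn on the same
# input. NumPy promotes to complex128, so the check ignores dtype differences.
# The input shape and tolerances are illustrative assumptions.
import numpy as onp
import jax.numpy as jnp

x = onp.random.RandomState(0).randn(2, 3, 4).astype(onp.float32)
onp.testing.assert_allclose(onp.asarray(jnp.fft.fftn(x, axes=(1, 2))),
                            onp.fft.fftn(x, axes=(1, 2)),
                            rtol=1e-4, atol=1e-4)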
Ejemplo n.º 20
0
class LaxBackedScipySignalTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.stats implementations"""
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_op={}_xshape=[{}]_yshape=[{}]_mode={}".format(
                op, jtu.format_shape_dtype_string(xshape, dtype),
                jtu.format_shape_dtype_string(yshape, dtype), mode),
            "xshape":
            xshape,
            "yshape":
            yshape,
            "dtype":
            dtype,
            "mode":
            mode,
            "jsp_op":
            getattr(jsp_signal, op),
            "osp_op":
            getattr(osp_signal, op)
        } for mode in ['full', 'same', 'valid']
                            for op in ['convolve', 'correlate']
                            for dtype in default_dtypes
                            for xshape in onedim_shapes
                            for yshape in onedim_shapes))
    def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
        osp_fun = partial(osp_op, mode=mode)
        jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
        tol = {onp.float16: 1e-2, onp.float32: 1e-2, onp.float64: 1e-8}
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "op={}_xshape=[{}]_yshape=[{}]_mode={}".format(
                op, jtu.format_shape_dtype_string(xshape, dtype),
                jtu.format_shape_dtype_string(yshape, dtype), mode),
            "xshape":
            xshape,
            "yshape":
            yshape,
            "dtype":
            dtype,
            "mode":
            mode,
            "jsp_op":
            getattr(jsp_signal, op),
            "osp_op":
            getattr(osp_signal, op)
        } for mode in ['full', 'same', 'valid']
                            for op in ['convolve2d', 'correlate2d']
                            for dtype in default_dtypes
                            for xshape in twodim_shapes
                            for yshape in twodim_shapes))
    def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
        osp_fun = partial(osp_op, mode=mode)
        jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
        tol = {onp.float16: 1e-2, onp.float32: 1e-2, onp.float64: 1e-14}
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)
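
# A minimal, standalone sketch (not part of the test class above) of the
# pattern these tests follow: jax.scipy.signal.convolve against
# scipy.signal.convolve on the same inputs, with HIGHEST precision so float32
# results stay close. Shapes and tolerances are illustrative assumptions.
import numpy as onp
import scipy.signal as osp_signal
import jax.numpy as jnp
import jax.scipy.signal as jsp_signal
from jax import lax

x = onp.random.RandomState(0).randn(8).astype(onp.float32)
y = onp.random.RandomState(1).randn(3).astype(onp.float32)
out_jax = jsp_signal.convolve(jnp.asarray(x), jnp.asarray(y), mode="same",
                              precision=lax.Precision.HIGHEST)
onp.testing.assert_allclose(onp.asarray(out_jax),
                            osp_signal.convolve(x, y, mode="same"),
                            rtol=1e-2, atol=1e-2)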
Ejemplo n.º 21
0
class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):

  def test_primitive_coverage(self):
    """Fail if there are JAX primitives that are not implemented."""
    # Harvest primitives from XLA translation tables
    all_primitives = (set(xla.translations)
                      | set(xla.backend_specific_translations['cpu'])
                      | set(xla.backend_specific_translations['gpu'])
                      | set(xla.backend_specific_translations['tpu'])
                      | set(xla.initial_style_translations)
                      | set(xla.parallel_translations))

    tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl) | set(jax.experimental.jax2tf.jax2tf.tf_impl_with_avals)
    tf_not_yet_impl = set(jax.experimental.jax2tf.jax2tf.tf_not_yet_impl)

    all_primitives = tuple(sorted(all_primitives, key=str))
    for p in all_primitives:
      # TODO: remove tie_in once omnistaging is on by default
      if p.name == "axis_index" or p.name == "tie_in":
        continue
      if p in tf_not_yet_impl:
        self.assertNotIn(p, tf_impl)  # Should not be in both tf_impl and tf_not_yet_impl
      else:
        self.assertIn(p, tf_impl)

  @parameterized.named_parameters(
    dict(testcase_name=f"_{f_jax.__name__}",
         f_jax=f_jax)
    for f_jax in [jnp.add, jnp.subtract, jnp.multiply, jnp.divide,
                  jnp.less, jnp.less_equal, jnp.equal, jnp.greater,
                  jnp.greater_equal, jnp.not_equal, jnp.maximum,
                  jnp.minimum])
  def test_type_promotion(self, f_jax=jnp.add):
    # We only test a few types here, as tensorflow does not support many
    # types like uint* or bool in binary ops.
    types = [dtypes.bfloat16, np.int32, np.int64, np.float32]
    for x_dtype in types:
      for y_dtype in types:
        x = np.array([1, 2], dtype=x_dtype)
        y = np.array([3, 4], dtype=y_dtype)
        self.ConvertAndCompare(f_jax, x, y)

  def test_concat(self):
    values = [np.array([1, 2], dtype=np.float32),
              np.array([1, 2], dtype=np.int32),
              np.array([1, 2], dtype=np.int8)]
    f_jax = jax.jit(lambda x: jnp.concatenate(x, axis=0))
    self.ConvertAndCompare(f_jax, values)

  @primitive_harness.parameterized(primitive_harness.lax_pad)
  def test_pad(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_top_k)
  def test_top_k(self, harness: primitive_harness.Harness):
    if (harness.params["k"] > harness.params["shape"][-1] or
        harness.params["k"] < 0):
      with self.assertRaisesRegex(ValueError, "k argument to top_k must be"):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    elif harness.params["dtype"] in jtu.dtypes.complex:
      # TODO(necula): fix top_k complex bug on TPU
      if jtu.device_under_test() == "tpu":
        raise unittest.SkipTest("top_k complex on TPU raises different error")
      with self.assertRaisesRegex(RuntimeError, "Unimplemented: complex comparison"):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    # TODO: TF and JAX sort [inf, nan] differently.
    elif harness.name.startswith("nan_"):
      raise unittest.SkipTest("inconsistent [nan, inf] sorting")
    else:
      self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_sort)
  def test_sort(self, harness: primitive_harness.Harness):
    if (jtu.device_under_test() == "gpu" and
        len(harness.arg_descriptors) == 4 and
        not harness.params["is_stable"]):
      # TODO: fix the TF GPU test
      raise unittest.SkipTest("GPU tests are running TF on CPU")
    if jtu.device_under_test() == "tpu" and harness.params["dtype"] in jtu.dtypes.complex:
      raise unittest.SkipTest("JAX sort is not implemented on TPU for complex")
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_fft)
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_fft(self, harness: primitive_harness.Harness):
    if len(harness.params["fft_lengths"]) > 3:
      with self.assertRaisesRegex(RuntimeError, "FFT only supports ranks 1-3"):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    elif (jtu.device_under_test() == "tpu" and
          len(harness.params["fft_lengths"]) > 1):
      # TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX
      with self.assertRaisesRegex(RuntimeError,
                                  "only 1D FFT is currently supported."):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    else:
      tol = None
      if jtu.device_under_test() == "gpu":
        if harness.params["dtype"] in jtu.dtypes.boolean:
          tol = 0.01
        else:
          tol = 1e-3
      self.ConvertAndCompare(harness.dyn_fun,
                             *harness.dyn_args_maker(self.rng()),
                             atol=tol, rtol=tol)

  @primitive_harness.parameterized(primitive_harness.lax_linalg_cholesky)
  def test_cholesky(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]
    if dtype in [dtypes.bfloat16, np.float16]:
      raise unittest.SkipTest("Cholesky decomposition not supported for "
                              "(b)float16 in JAX.")
    operand = harness.dyn_args_maker(self.rng())[0]
    operand = np.matmul(operand, jnp.conj(np.swapaxes(operand, -1, -2)))
    tol = None
    # TODO(bchetioui): very high discrepancy in the float32/complex64 case
    if dtype in [np.float32, np.complex64]:
      tol = 1e-2
    # TODO(bchetioui): also high discrepancy in the float64/complex128 case
    elif dtype in [np.float64, np.complex128]:
      tol = 1e-11

    def custom_assert(result_jax, result_tf):
      # cholesky_p returns garbage in the strictly upper triangular part of the
      # result, so we can safely ignore that part.
      self.assertAllClose(jnp.tril(result_jax), result_tf, atol=tol)

    self.ConvertAndCompare(harness.dyn_fun, operand,
                           custom_assert=custom_assert,
                           always_custom_assert=True)

  @primitive_harness.parameterized(primitive_harness.lax_linalg_qr)
  def test_qr(self, harness: primitive_harness.Harness):
    # See jax.lib.lapack.geqrf for the list of compatible types

    dtype = harness.params["dtype"]
    dut = jtu.device_under_test()
    # These cases are not implemented in JAX
    if dtype in (jtu.dtypes.all_integer + [jnp.bfloat16]):
      unimplemented_jax = True
    elif dtype is np.complex64 and dut == "tpu":
      unimplemented_jax = True
    elif dtype is np.float16 and dut in ("cpu", "gpu"):
      unimplemented_jax = True
    else:
      unimplemented_jax = False

    if unimplemented_jax:
      raise unittest.SkipTest(f"QR not implemented in JAX for {dtype} on {dut}")

    # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
    # - for now, the HLO QR implementation called when compiling with TF is
    #   expected to be slower than the custom calls made in JAX.
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=1e-5, rtol=1e-5)

  @primitive_harness.parameterized(primitive_harness.lax_linalg_svd)
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_svd(self, harness: primitive_harness.Harness):
    if harness.params["dtype"] in [np.float16, dtypes.bfloat16]:
      if jtu.device_under_test() != "tpu":
        # Does not work in JAX
        with self.assertRaisesRegex(NotImplementedError, "Unsupported dtype"):
          harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
        return

    if harness.params["dtype"] in [np.complex64, np.complex128]:
      if jtu.device_under_test() == "tpu":
        # TODO: on JAX on TPU there is no SVD implementation for complex
        with self.assertRaisesRegex(RuntimeError,
                                    "Binary op compare with different element types"):
          harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
        return

    def _custom_assert(r_jax, r_tf, atol=1e-6, rtol=1e-6):
      def _reconstruct_operand(result, is_tf: bool):
        # Reconstructing operand as documented in numpy.linalg.svd (see
        # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html)
        s, u, v = result
        if is_tf:
          s = s.numpy()
          u = u.numpy()
          v = v.numpy()
        U = u[..., :s.shape[-1]]
        V = v[..., :s.shape[-1], :]
        S = s[..., None, :]
        return jnp.matmul(U * S, V), s.shape, u.shape, v.shape

      if harness.params["compute_uv"]:
        r_jax_reconstructed = _reconstruct_operand(r_jax, False)
        r_tf_reconstructed = _reconstruct_operand(r_tf, True)
        self.assertAllClose(r_jax_reconstructed, r_tf_reconstructed,
                            atol=atol, rtol=rtol)
      else:
        self.assertAllClose(r_jax, r_tf, atol=atol, rtol=rtol)

    tol = 1e-4
    custom_assert = partial(_custom_assert, atol=tol, rtol=tol)

    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol,
                           custom_assert=custom_assert,
                           always_custom_assert=True)

  @primitive_harness.parameterized(primitive_harness.lax_select_and_gather_add)
  @jtu.ignore_warning(category=UserWarning,
                      message="Using reduced precision for gradient.*")
  def test_select_and_gather_add(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_reduce_window)
  def test_reduce_window(self, harness: primitive_harness.Harness):
    dtype = harness.params['dtype']

    if (jtu.device_under_test() == 'tpu' and dtype is np.complex64):
      raise unittest.SkipTest(
          'TODO: JAX reduce_window on TPU does not handle complex64'
      )

    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_linalg_eig)
  def test_eig(self, harness: primitive_harness.Harness):
    operand = harness.dyn_args_maker(self.rng())[0]
    compute_left_eigenvectors = harness.params["compute_left_eigenvectors"]
    compute_right_eigenvectors = harness.params["compute_right_eigenvectors"]
    dtype = harness.params["dtype"]

    if jtu.device_under_test() != "cpu":
      raise unittest.SkipTest("eig only supported on CPU in JAX")

    if dtype in [np.float16, dtypes.bfloat16]:
      raise unittest.SkipTest("eig unsupported with (b)float16 in JAX")

    def custom_assert(result_jax, result_tf):
      result_tf = tuple(map(lambda e: e.numpy(), result_tf))
      inner_dimension = operand.shape[-1]
      # Test ported from tests.lax_test.testEig
      # Norm, adjusted for dimension and type.
      def norm(x):
        norm = np.linalg.norm(x, axis=(-2, -1))
        return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps)

      def check_right_eigenvectors(a, w, vr):
        self.assertTrue(
          np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

      def check_left_eigenvectors(a, w, vl):
        rank = len(a.shape)
        aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
        wC = jnp.conj(w)
        check_right_eigenvectors(aH, wC, vl)

      def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
        tol = None
        # TODO(bchetioui): numerical discrepancies
        if dtype in [np.float32, np.complex64]:
          tol = 1e-4
        elif dtype in [np.float64, np.complex128]:
          tol = 1e-13
        closest_diff = min(abs(eigenvalues_array - eigenvalue))
        self.assertAllClose(closest_diff, np.array(0., closest_diff.dtype),
                            atol=tol)

      all_w_jax, all_w_tf = result_jax[0], result_tf[0]
      for idx in itertools.product(*map(range, operand.shape[:-2])):
        w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
        for i in range(inner_dimension):
          check_eigenvalue_is_in_array(w_jax[i], w_tf)
          check_eigenvalue_is_in_array(w_tf[i], w_jax)

      if compute_left_eigenvectors:
        check_left_eigenvectors(operand, all_w_tf, result_tf[1])
      if compute_right_eigenvectors:
        check_right_eigenvectors(operand, all_w_tf,
                                 result_tf[1 + compute_left_eigenvectors])

    self.ConvertAndCompare(harness.dyn_fun, operand,
                           custom_assert=custom_assert)

  @primitive_harness.parameterized(primitive_harness.lax_linalg_eigh)
  def test_eigh(self, harness: primitive_harness.Harness):
    operand = harness.dyn_args_maker(self.rng())[0]
    lower = harness.params["lower"]
    # Make operand self-adjoint
    operand = (operand + np.conj(np.swapaxes(operand, -1, -2))) / 2
    # Make operand lower/upper triangular
    triangular_operand = np.tril(operand) if lower else np.triu(operand)
    dtype = harness.params["dtype"]

    if (dtype in [np.complex64, np.complex128] and
        jtu.device_under_test() == "tpu"):
      raise unittest.SkipTest("TODO: complex eigh not supported on TPU in JAX")

    def custom_assert(result_jax, result_tf):
      result_tf = tuple(map(lambda e: e.numpy(), result_tf))
      inner_dimension = operand.shape[-1]

      def check_right_eigenvectors(a, w, vr):
        tol = 1e-16
        # TODO(bchetioui): tolerance needs to be very high in compiled mode,
        # specifically for eigenvectors.
        if dtype == np.float64:
          tol = 1e-6
        elif dtype == np.float32:
          tol = 1e-2
        elif dtype in [dtypes.bfloat16, np.complex64]:
          tol = 1e-3
        elif dtype == np.complex128:
          tol = 1e-13
        self.assertAllClose(np.matmul(a, vr) - w[..., None, :] * vr,
                            np.zeros(a.shape, dtype=vr.dtype),
                            atol=tol)

      def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array):
        tol = None
        if dtype in [dtypes.bfloat16, np.float32, np.complex64]:
          tol = 1e-3
        elif dtype in [np.float64, np.complex128]:
          tol = 1e-11
        closest_diff = min(abs(eigenvalues_array - eigenvalue))
        self.assertAllClose(closest_diff, np.array(0., closest_diff.dtype),
                            atol=tol)

      _, all_w_jax = result_jax
      all_vr_tf, all_w_tf = result_tf

      for idx in itertools.product(*map(range, operand.shape[:-2])):
        w_jax, w_tf = all_w_jax[idx], all_w_tf[idx]
        for i in range(inner_dimension):
          check_eigenvalue_is_in_array(w_jax[i], w_tf)
          check_eigenvalue_is_in_array(w_tf[i], w_jax)

      check_right_eigenvectors(operand, all_w_tf, all_vr_tf)

    # On CPU and GPU, JAX makes custom calls
    always_custom_assert = True
    # On TPU, JAX calls xops.Eigh
    if jtu.device_under_test() == "tpu":
      always_custom_assert = False

    self.ConvertAndCompare(harness.dyn_fun, triangular_operand,
                           custom_assert=custom_assert,
                           always_custom_assert=always_custom_assert)

  @primitive_harness.parameterized(
      primitive_harness.lax_linalg_triangular_solve)
  def test_triangular_solve(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]
    if dtype == np.float16 and jtu.device_under_test() == "gpu":
      raise unittest.SkipTest(
        f"Triangular solve is not implemented in JAX for dtype {dtype}")
    atol = rtol = None
    if dtype == np.float32:
      atol = rtol = 1e-5
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=atol, rtol=rtol)

  @primitive_harness.parameterized(primitive_harness.lax_unary_elementwise)
  def test_unary_elementwise(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]
    lax_name = harness.params["lax_name"]
    arg, = harness.dyn_args_maker(self.rng())
    custom_assert = None
    if lax_name == "digamma":
      # TODO(necula): fix bug with digamma/(f32|f16) on TPU
      if dtype in [np.float16, np.float32] and jtu.device_under_test() == "tpu":
        raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")

      # In the bfloat16 case, TF and lax both return NaN in undefined cases.
      if dtype is not dtypes.bfloat16:
        # digamma is not defined at 0 and -1
        def custom_assert(result_jax, result_tf):
          # lax.digamma returns NaN and tf.math.digamma returns inf
          special_cases = (arg == 0.) | (arg == -1.)
          nr_special_cases = np.count_nonzero(special_cases)
          self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)),
                              result_jax[special_cases])
          self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)),
                              result_tf[special_cases])
          # non-special cases are equal
          self.assertAllClose(result_jax[~ special_cases],
                              result_tf[~ special_cases])
    if lax_name == "erf_inv":
      # TODO(necula): fix erf_inv bug on TPU
      if jtu.device_under_test() == "tpu":
        raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan")
      # TODO: investigate: in the (b)float16 cases, TF and lax both return the
      # same result in undefined cases.
      if dtype not in [np.float16, dtypes.bfloat16]:
        # erf_inv is not defined for arg <= -1 or arg >= 1
        def custom_assert(result_jax, result_tf):  # noqa: F811
          # for arg < -1 or arg > 1
          # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf
          special_cases = (arg < -1.) | (arg > 1.)
          nr_special_cases = np.count_nonzero(special_cases)
          self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan),
                                      dtype=dtype),
                              result_jax[special_cases])
          signs = np.where(arg[special_cases] < 0., -1., 1.)
          self.assertAllClose(np.full((nr_special_cases,),
                                      signs * dtype(np.inf), dtype=dtype),
                              result_tf[special_cases])
          # non-special cases are equal
          self.assertAllClose(result_jax[~ special_cases],
                              result_tf[~ special_cases])
    atol = None
    if jtu.device_under_test() == "gpu":
      # TODO(necula): revisit once we fix the GPU tests
      atol = 1e-3
    self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert,
                           atol=atol)

  @primitive_harness.parameterized(primitive_harness.lax_bitwise_not)
  def test_bitwise_not(self, harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_population_count)
  def test_population_count(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_add_mul)
  def test_add_mul(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_min_max)
  def test_min_max(self, harness: primitive_harness.Harness):
    # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN;
    # JAX always returns NaN, while TF returns the value NaN is compared with.
    def custom_assert(result_jax, result_tf):
      mask = np.isnan(result_jax)
      self.assertAllClose(result_jax[~ mask], result_tf[~ mask])
    # TODO(bchetioui): figure out why we need always_custom_assert=True
    always_custom_assert = True
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           custom_assert=custom_assert,
                           always_custom_assert=always_custom_assert)

  @primitive_harness.parameterized(primitive_harness.lax_binary_elementwise)
  def test_binary_elementwise(self, harness):
    tol = None
    lax_name, dtype = harness.params["lax_name"], harness.params["dtype"]
    if lax_name in ("igamma", "igammac"):
      # TODO(necula): fix bug with igamma/f16
      if dtype in [np.float16, dtypes.bfloat16]:
        raise unittest.SkipTest("TODO: igamma(c) unsupported with (b)float16 in JAX")
      # TODO(necula): fix bug with igamma/f32 on TPU
      if dtype is np.float32 and jtu.device_under_test() == "tpu":
        raise unittest.SkipTest("TODO: fix bug: nan vs not-nan")
    arg1, arg2 = harness.dyn_args_maker(self.rng())
    custom_assert = None
    if lax_name == "igamma":
      # igamma is not defined when the first argument is <=0
      def custom_assert(result_jax, result_tf):
        # lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0
        special_cases = (arg1 == 0.) & (arg2 == 0.)
        nr_special_cases = np.count_nonzero(special_cases)
        self.assertAllClose(np.full((nr_special_cases,), np.nan, dtype=dtype),
                            result_jax[special_cases])
        self.assertAllClose(np.full((nr_special_cases,), 0., dtype=dtype),
                            result_tf[special_cases])
        # non-special cases are equal
        self.assertAllClose(result_jax[~ special_cases],
                            result_tf[~ special_cases])
    if lax_name == "igammac":
      # On GPU, tolerance also needs to be adjusted in compiled mode
      if dtype == np.float64 and jtu.device_under_test() == 'gpu':
        tol = 1e-14
      # igammac is not defined when the first argument is <=0
      def custom_assert(result_jax, result_tf):  # noqa: F811
        # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN
        special_cases = (arg1 <= 0.) | (arg2 <= 0)
        nr_special_cases = np.count_nonzero(special_cases)
        self.assertAllClose(np.full((nr_special_cases,), 1., dtype=dtype),
                            result_jax[special_cases])
        self.assertAllClose(np.full((nr_special_cases,), np.nan, dtype=dtype),
                            result_tf[special_cases])
        # On CPU, tolerance only needs to be adjusted in eager & graph modes
        tol = None
        if dtype == np.float64:
          tol = 1e-14

        # non-special cases are equal
        self.assertAllClose(result_jax[~ special_cases],
                            result_tf[~ special_cases], atol=tol, rtol=tol)
    self.ConvertAndCompare(harness.dyn_fun, arg1, arg2,
                           custom_assert=custom_assert, atol=tol, rtol=tol)

  @primitive_harness.parameterized(primitive_harness.lax_binary_elementwise_logical)
  def test_binary_elementwise_logical(self, harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))


  @primitive_harness.parameterized(primitive_harness.lax_betainc)
  def test_betainc(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]
    # TODO: tf.math.betainc (https://www.tensorflow.org/api_docs/python/tf/math/betainc)
    # only supports float32/float64, so other dtypes are skipped below.
    # TODO(bchetioui): investigate why the test actually fails in JAX.
    if dtype in [np.float16, dtypes.bfloat16]:
      raise unittest.SkipTest("(b)float16 not implemented in TF")

    tol = None
    if dtype is np.float64:
      tol = 1e-14

    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol)

  # TODO(necula): combine tests that are identical except for the harness;
  # wait until we get more experience with using harnesses.
  @primitive_harness.parameterized(primitive_harness.lax_shift_left)
  def test_shift_left(self, harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_shift_right_logical)
  def test_shift_right_logical(self, harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_shift_right_arithmetic)
  def test_shift_right_arithmetic(self, harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_slice)
  def test_slice(self, harness):
    # lax.slice rejects negative or out-of-range indices; check that, and skip jax2tf
    if any(si < 0 or si >= sh or li < 0 or li > sh
           for sh, si, li in zip(harness.params["shape"],
                                 harness.params["start_indices"],
                                 harness.params["limit_indices"])):
      with self.assertRaisesRegex(TypeError, ""):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    else:
      self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_dynamic_slice)
  def test_dynamic_slice(self, harness):
    # lax.dynamic_slice rejects negative or too-large slice sizes; check this, and skip jax2tf
    args = harness.dyn_args_maker(self.rng())
    if any(li - si < 0 or li - si >= sh
           for sh, si, li in zip(harness.params["shape"],
                                 harness.params["start_indices"],
                                 harness.params["limit_indices"])):
      with self.assertRaisesRegex(TypeError, ""):
        harness.dyn_fun(*args)
      return

    self.ConvertAndCompare(harness.dyn_fun, *args)

  @primitive_harness.parameterized(primitive_harness.lax_dynamic_update_slice)
  def test_dynamic_update_slice(self, harness):
    # JAX.dynamic_update_slice rejects update slices too big; check, and skip jax2tf
    if any(ush > sh
           for sh, ush in zip(harness.params["shape"],
                              harness.params["update_shape"])):
      with self.assertRaisesRegex(TypeError, ""):
        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
    else:
      self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_squeeze)
  def test_squeeze(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_dot_general)
  def test_dot_general(self, harness: primitive_harness.Harness):
    tol, dtype = None, harness.params["dtype"]
    if dtype == dtypes.bfloat16:
      tol = 0.3
    elif dtype in [np.complex64, np.float32]:
      if jtu.device_under_test() == "tpu":
        tol = 0.1 if dtype == np.float32 else 0.3
      else:
        tol = 1e-5
    elif dtype == np.float16:
      if jtu.device_under_test() == "gpu":
        tol = 0.1
      else:
        tol = 0.01
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol)

  @primitive_harness.parameterized(primitive_harness.lax_conv_general_dilated)
  def test_conv_general_dilated(self, harness: primitive_harness.Harness):
    dtype, device = harness.params["dtype"], jtu.device_under_test()
    if device == "gpu" and dtype in [np.complex64, np.complex128]:
      raise unittest.SkipTest("TODO: crash on GPU in TF")

    tol = None
    if device == "gpu":
      tol = 1e-4
    elif device == "tpu":
      tol = 1e-3
    # TODO(bchetioui): significant discrepancies in some float16 cases.
    if dtype == np.float16:
      tol = 1.
    # TODO(bchetioui): slight occasional discrepancy in float32 cases.
    elif dtype == np.float32:
      tol = 0.5 if device == "tpu" else (1e-3 if device == "gpu" else 1e-4)
    elif dtype == np.complex64 and device == "tpu":
      tol = 0.1
    # TODO(bchetioui): slight discrepancy when going through the path using
    # tf.nn.convolution.
    elif dtype == np.float64 and device == "cpu":
      tol = 1e-13
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           atol=tol, rtol=tol)

  @primitive_harness.parameterized(primitive_harness.lax_gather)
  def test_gather(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  @primitive_harness.parameterized(primitive_harness.lax_scatter)
  def test_scatter(self, harness: primitive_harness.Harness):
    f_name = harness.params['f_lax'].__name__
    dtype = harness.params['dtype']

    if jtu.device_under_test() == 'tpu':
      if dtype is np.complex64 and f_name in ['scatter_min', 'scatter_max']:
          raise unittest.SkipTest(f"TODO: complex {f_name} on TPU fails in JAX")

    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  def test_boolean_gather(self):
    values = np.array([[True, True], [False, True], [False, False]],
                      dtype=np.bool_)
    indices = np.array([0, 1], dtype=np.int32)
    for axis in [0, 1]:
      f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
      self.ConvertAndCompare(f_jax, values, indices)

  def test_gather_rank_change(self):
    params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]])
    indices = jnp.array([[1, 1, 2], [0, 1, 0]])
    f_jax = jax.jit(lambda i: params[i])
    self.ConvertAndCompare(f_jax, indices)

  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_{f_jax.__name__}",
         f_jax=f_jax)
    for f_jax in REDUCE))
  def test_reduce_ops_with_numerical_input(self, f_jax):
    values = np.array([1, 2, 3], dtype=np.float32)
    self.ConvertAndCompare(f_jax, values)

  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_{f_jax.__name__}",
         f_jax=f_jax)
    for f_jax in (jnp.cumsum, jnp.cumprod)))
  def test_cumulated_ops(self, f_jax):
    values = np.array([1, 2, 3], dtype=np.float32)
    self.ConvertAndCompare(f_jax, values)

  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_{op.__name__}",
         op=op)
    for op in INDEX))
  def test_scatter_static(self, op):
    values = np.ones((5, 6), dtype=np.float32)
    update = np.float32(6.)
    f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u))
    self.ConvertAndCompare(f_jax, values, update)

  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_{f_jax.__name__}",
         f_jax=f_jax)
    for f_jax in REDUCE))
  def test_reduce_ops_with_boolean_input(self, f_jax):
    values = np.array([True, False, True], dtype=np.bool_)
    self.ConvertAndCompare(f_jax, values)

  @primitive_harness.parameterized(primitive_harness.random_gamma)
  def test_random_gamma(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                           rtol=1e-5)

  @primitive_harness.parameterized(primitive_harness.random_split)
  def test_random_split(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

  def test_zeros_like(self):
    v = np.float32(2.)
    f_jax = jax.ad_util.zeros_like_jaxval
    self.ConvertAndCompare(f_jax, v)

  def test_stop_gradient(self):
    f = jax2tf.convert(lax.stop_gradient)
    self.assertEqual(f(tf.ones([])), 1.)

  # test_bfloat16_constant checks that https://github.com/google/jax/issues/3942 is
  # fixed
  def test_bfloat16_constant(self):
    def jax_fn_scalar(x):
      x = x.astype(jnp.bfloat16)
      x *= 2.
      return x

    def jax_fn_array(x):
      x = x.astype(jnp.bfloat16)
      x *= np.array([1.5, 2.5, 3.5], jnp.bfloat16)
      return x

    tf_fn_scalar = jax2tf.convert(jax_fn_scalar)
    self.assertAllClose(tf_fn_scalar(1.375).numpy(), jnp.bfloat16(2.750))

    tf_fn_array = jax2tf.convert(jax_fn_array)
    self.assertAllClose(tf_fn_array(np.array([3, 4, 5])),
                        np.array([4.5, 10, 17.5], jnp.bfloat16))
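
A note for readers: the pattern exercised throughout this example is converting a JAX function with jax2tf.convert and comparing it numerically against the original. A minimal standalone sketch of that round trip (the function name and tolerance below are illustrative, not taken from the test suite):

import numpy as np
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def scale_and_sum(x):
  # A simple JAX function to round-trip through TensorFlow.
  return jnp.sum(2.0 * x)

x = np.arange(4.0, dtype=np.float32)
f_tf = jax2tf.convert(scale_and_sum)   # TF-compatible callable
np.testing.assert_allclose(scale_and_sum(x),               # JAX result
                           f_tf(tf.constant(x)).numpy(),   # TF result
                           rtol=1e-6)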
Example No. 22
class LaxBackedScipyTests(jtu.JaxTestCase):
    def _fetch_preconditioner(self, preconditioner, A, rng=None):
        """
    Returns one of various preconditioning matrices depending on the identifier
    `preconditioner' and the input matrix A whose inverse it supposedly
    approximates.
    """
        if preconditioner == 'identity':
            M = np.eye(A.shape[0], dtype=A.dtype)
        elif preconditioner == 'random':
            if rng is None:
                rng = jtu.rand_default(self.rng())
            M = np.linalg.inv(rand_sym_pos_def(rng, A.shape, A.dtype))
        elif preconditioner == 'exact':
            M = np.linalg.inv(A)
        else:
            M = None
        return M

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_preconditioner={}".format(
                jtu.format_shape_dtype_string(shape, dtype), preconditioner),
            "shape":
            shape,
            "dtype":
            dtype,
            "preconditioner":
            preconditioner
        } for shape in [(4, 4), (7, 7)]
                            for dtype in [np.float64, np.complex128]
                            for preconditioner in
                            [None, 'identity', 'exact', 'random']))
    def test_cg_against_scipy(self, shape, dtype, preconditioner):
        if not config.x64_enabled:
            raise unittest.SkipTest("requires x64 mode")

        rng = jtu.rand_default(self.rng())
        A = rand_sym_pos_def(rng, shape, dtype)
        b = rng(shape[:1], dtype)
        M = self._fetch_preconditioner(preconditioner, A, rng=rng)

        def args_maker():
            return A, b

        self._CheckAgainstNumpy(partial(scipy_cg, M=M, maxiter=1),
                                partial(lax_cg, M=M, maxiter=1),
                                args_maker,
                                tol=1e-12)

        self._CheckAgainstNumpy(partial(scipy_cg, M=M, maxiter=3),
                                partial(lax_cg, M=M, maxiter=3),
                                args_maker,
                                tol=1e-12)

        self._CheckAgainstNumpy(np.linalg.solve,
                                partial(lax_cg, M=M, atol=1e-10),
                                args_maker,
                                tol=1e-6)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(2, 2)] for dtype in float_types + complex_types))
    def test_cg_as_solve(self, shape, dtype):

        rng = jtu.rand_default(self.rng())
        a = rng(shape, dtype)
        b = rng(shape[:1], dtype)

        expected = np.linalg.solve(posify(a), b)
        actual = lax_cg(posify(a), b)
        self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

        actual = jit(lax_cg)(posify(a), b)
        self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5)

        # numerical gradients are only well defined if ``a`` is guaranteed to be
        # positive definite.
        jtu.check_grads(lambda x, y: lax_cg(posify(x), y), (a, b),
                        order=2,
                        rtol=2e-1)

    def test_cg_ndarray(self):
        A = lambda x: 2 * x
        b = jnp.arange(9.0).reshape((3, 3))
        expected = b / 2
        actual, _ = jax.scipy.sparse.linalg.cg(A, b)
        self.assertAllClose(expected, actual)

    def test_cg_pytree(self):
        A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
        b = {"a": 1.0, "b": -4.0}
        expected = {"a": 4.0, "b": -6.0}
        actual, _ = jax.scipy.sparse.linalg.cg(A, b)
        self.assertEqual(expected.keys(), actual.keys())
        self.assertAlmostEqual(expected["a"], actual["a"], places=6)
        self.assertAlmostEqual(expected["b"], actual["b"], places=6)

    def test_cg_errors(self):
        A = lambda x: x
        b = jnp.zeros((2, ))
        with self.assertRaisesRegex(
                ValueError, "x0 and b must have matching tree structure"):
            jax.scipy.sparse.linalg.cg(A, {'x': b}, {'y': b})
        with self.assertRaisesRegex(ValueError,
                                    "x0 and b must have matching shape"):
            jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis])
        with self.assertRaisesRegex(ValueError, "must be a square matrix"):
            jax.scipy.sparse.linalg.cg(jnp.zeros((3, 2)), jnp.zeros((2, )))
        with self.assertRaisesRegex(
                TypeError,
                "linear operator must be either a function or ndarray"):
            jax.scipy.sparse.linalg.cg([[1]], jnp.zeros((1, )))

    def test_cg_without_pytree_equality(self):
        @register_pytree_node_class
        class MinimalPytree:
            def __init__(self, value):
                self.value = value

            def tree_flatten(self):
                return [self.value], None

            @classmethod
            def tree_unflatten(cls, aux_data, children):
                return cls(*children)

        A = lambda x: MinimalPytree(2 * x.value)
        b = MinimalPytree(jnp.arange(5.0))
        expected = b.value / 2
        actual, _ = jax.scipy.sparse.linalg.cg(A, b)
        self.assertAllClose(expected, actual.value)

    # BICGSTAB
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_preconditioner={}".format(
                jtu.format_shape_dtype_string(shape, dtype), preconditioner),
            "shape":
            shape,
            "dtype":
            dtype,
            "preconditioner":
            preconditioner
        } for shape in [(5, 5)] for dtype in [np.float64, np.complex128]
                            for preconditioner in
                            [None, 'identity', 'exact', 'random']))
    def test_bicgstab_against_scipy(self, shape, dtype, preconditioner):
        if not config.jax_enable_x64:
            raise unittest.SkipTest("requires x64 mode")

        rng = jtu.rand_default(self.rng())
        A = rng(shape, dtype)
        b = rng(shape[:1], dtype)
        M = self._fetch_preconditioner(preconditioner, A, rng=rng)

        def args_maker():
            return A, b

        self._CheckAgainstNumpy(partial(scipy_bicgstab, M=M, maxiter=1),
                                partial(lax_bicgstab, M=M, maxiter=1),
                                args_maker,
                                tol=1e-5)

        self._CheckAgainstNumpy(partial(scipy_bicgstab, M=M, maxiter=2),
                                partial(lax_bicgstab, M=M, maxiter=2),
                                args_maker,
                                tol=1e-4)

        self._CheckAgainstNumpy(partial(scipy_bicgstab, M=M, maxiter=1),
                                partial(lax_bicgstab, M=M, maxiter=1),
                                args_maker,
                                tol=1e-4)

        self._CheckAgainstNumpy(np.linalg.solve,
                                partial(lax_bicgstab, M=M, atol=1e-6),
                                args_maker,
                                tol=1e-4)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_preconditioner={}".format(
                jtu.format_shape_dtype_string(shape, dtype), preconditioner),
            "shape":
            shape,
            "dtype":
            dtype,
            "preconditioner":
            preconditioner
        } for shape in [(2, 2), (7, 7)]
                            for dtype in float_types + complex_types
                            for preconditioner in [None, 'identity', 'exact']))
    @jtu.skip_on_devices("gpu")
    def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner):
        A = jnp.eye(shape[1], dtype=dtype)
        solution = jnp.ones(shape[1], dtype=dtype)
        rng = jtu.rand_default(self.rng())
        M = self._fetch_preconditioner(preconditioner, A, rng=rng)
        b = matmul_high_precision(A, solution)
        tol = shape[0] * jnp.finfo(dtype).eps
        x, info = jax.scipy.sparse.linalg.bicgstab(A,
                                                   b,
                                                   tol=tol,
                                                   atol=tol,
                                                   M=M)
        using_x64 = solution.dtype in (np.float64, np.complex128)
        solution_tol = 1e-8 if using_x64 else 1e-4
        self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_preconditioner={}".format(
                jtu.format_shape_dtype_string(shape, dtype), preconditioner),
            "shape":
            shape,
            "dtype":
            dtype,
            "preconditioner":
            preconditioner
        } for shape in [(2, 2), (4, 4)]
                            for dtype in float_types + complex_types
                            for preconditioner in [None, 'identity', 'exact']))
    @jtu.skip_on_devices("gpu")
    def test_bicgstab_on_random_system(self, shape, dtype, preconditioner):
        rng = jtu.rand_default(self.rng())
        A = rng(shape, dtype)
        solution = rng(shape[1:], dtype)
        M = self._fetch_preconditioner(preconditioner, A, rng=rng)
        b = matmul_high_precision(A, solution)
        tol = shape[0] * jnp.finfo(A.dtype).eps
        x, info = jax.scipy.sparse.linalg.bicgstab(A,
                                                   b,
                                                   tol=tol,
                                                   atol=tol,
                                                   M=M)
        using_x64 = solution.dtype in (np.float64, np.complex128)
        solution_tol = 1e-8 if using_x64 else 1e-4
        self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
        # solve = lambda A, b: jax.scipy.sparse.linalg.bicgstab(A, b)[0]
        # jtu.check_grads(solve, (A, b), order=1, rtol=3e-1)

    def test_bicgstab_pytree(self):
        A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
        b = {"a": 1.0, "b": -4.0}
        expected = {"a": 4.0, "b": -6.0}
        actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b)
        self.assertEqual(expected.keys(), actual.keys())
        self.assertAlmostEqual(expected["a"], actual["a"], places=5)
        self.assertAlmostEqual(expected["b"], actual["b"], places=5)

    # GMRES
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_preconditioner={}_solve_method={}".format(
                    jtu.format_shape_dtype_string(shape, dtype),
                    preconditioner, solve_method),
                "shape":
                shape,
                "dtype":
                dtype,
                "preconditioner":
                preconditioner,
                "solve_method":
                solve_method
            } for shape in [(3, 3)] for dtype in [np.float64, np.complex128]
            for preconditioner in [None, 'identity', 'exact', 'random']
            for solve_method in ['incremental', 'batched']))
    def test_gmres_against_scipy(self, shape, dtype, preconditioner,
                                 solve_method):
        if not config.x64_enabled:
            raise unittest.SkipTest("requires x64 mode")

        rng = jtu.rand_default(self.rng())
        A = rng(shape, dtype)
        b = rng(shape[:1], dtype)
        M = self._fetch_preconditioner(preconditioner, A, rng=rng)

        def args_maker():
            return A, b

        self._CheckAgainstNumpy(partial(scipy_gmres, M=M, restart=1,
                                        maxiter=1),
                                partial(lax_gmres,
                                        M=M,
                                        restart=1,
                                        maxiter=1,
                                        solve_method=solve_method),
                                args_maker,
                                tol=1e-10)

        self._CheckAgainstNumpy(partial(scipy_gmres, M=M, restart=1,
                                        maxiter=2),
                                partial(lax_gmres,
                                        M=M,
                                        restart=1,
                                        maxiter=2,
                                        solve_method=solve_method),
                                args_maker,
                                tol=1e-10)

        self._CheckAgainstNumpy(partial(scipy_gmres, M=M, restart=2,
                                        maxiter=1),
                                partial(lax_gmres,
                                        M=M,
                                        restart=2,
                                        maxiter=1,
                                        solve_method=solve_method),
                                args_maker,
                                tol=1e-10)

        self._CheckAgainstNumpy(np.linalg.solve,
                                partial(lax_gmres,
                                        M=M,
                                        atol=1e-6,
                                        solve_method=solve_method),
                                args_maker,
                                tol=1e-10)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_preconditioner={}_solve_method={}".format(
                jtu.format_shape_dtype_string(shape, dtype), preconditioner,
                solve_method),
            "shape":
            shape,
            "dtype":
            dtype,
            "preconditioner":
            preconditioner,
            "solve_method":
            solve_method
        } for shape in [(2, 2), (7, 7)]
                            for dtype in float_types + complex_types
                            for preconditioner in [None, 'identity', 'exact']
                            for solve_method in ['batched', 'incremental']))
    @jtu.skip_on_devices("gpu")
    def test_gmres_on_identity_system(self, shape, dtype, preconditioner,
                                      solve_method):
        A = jnp.eye(shape[1], dtype=dtype)

        solution = jnp.ones(shape[1], dtype=dtype)
        rng = jtu.rand_default(self.rng())
        M = self._fetch_preconditioner(preconditioner, A, rng=rng)
        b = matmul_high_precision(A, solution)
        restart = shape[-1]
        tol = shape[0] * jnp.finfo(dtype).eps
        x, info = jax.scipy.sparse.linalg.gmres(A,
                                                b,
                                                tol=tol,
                                                atol=tol,
                                                restart=restart,
                                                M=M,
                                                solve_method=solve_method)
        using_x64 = solution.dtype in (np.float64, np.complex128)
        solution_tol = 1e-8 if using_x64 else 1e-4
        self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_preconditioner={}_solve_method={}".format(
                jtu.format_shape_dtype_string(shape, dtype), preconditioner,
                solve_method),
            "shape":
            shape,
            "dtype":
            dtype,
            "preconditioner":
            preconditioner,
            "solve_method":
            solve_method
        } for shape in [(2, 2), (4, 4)]
                            for dtype in float_types + complex_types
                            for preconditioner in [None, 'identity', 'exact']
                            for solve_method in ['incremental', 'batched']))
    @jtu.skip_on_devices("gpu")
    def test_gmres_on_random_system(self, shape, dtype, preconditioner,
                                    solve_method):
        rng = jtu.rand_default(self.rng())
        A = rng(shape, dtype)

        solution = rng(shape[1:], dtype)
        M = self._fetch_preconditioner(preconditioner, A, rng=rng)
        b = matmul_high_precision(A, solution)
        restart = shape[-1]
        tol = shape[0] * jnp.finfo(A.dtype).eps
        x, info = jax.scipy.sparse.linalg.gmres(A,
                                                b,
                                                tol=tol,
                                                atol=tol,
                                                restart=restart,
                                                M=M,
                                                solve_method=solve_method)
        using_x64 = solution.dtype in (np.float64, np.complex128)
        solution_tol = 1e-8 if using_x64 else 1e-4
        self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol)
        # solve = lambda A, b: jax.scipy.sparse.linalg.gmres(A, b)[0]
        # jtu.check_grads(solve, (A, b), order=1, rtol=2e-1)

    def test_gmres_pytree(self):
        A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]}
        b = {"a": 1.0, "b": -4.0}
        expected = {"a": 4.0, "b": -6.0}
        actual, _ = jax.scipy.sparse.linalg.gmres(A, b)
        self.assertEqual(expected.keys(), actual.keys())
        self.assertAlmostEqual(expected["a"], actual["a"], places=5)
        self.assertAlmostEqual(expected["b"], actual["b"], places=5)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_preconditioner={}".format(
                jtu.format_shape_dtype_string(shape, dtype), preconditioner),
            "shape":
            shape,
            "dtype":
            dtype,
            "preconditioner":
            preconditioner
        } for shape in [(2, 2), (3, 3)]
                            for dtype in float_types + complex_types
                            for preconditioner in [None, 'identity']))
    def test_gmres_arnoldi_step(self, shape, dtype, preconditioner):
        """
    The Arnoldi decomposition within GMRES is correct.
    """
        if not config.x64_enabled:
            raise unittest.SkipTest("requires x64 mode")

        rng = jtu.rand_default(self.rng())
        A = rng(shape, dtype)
        M = self._fetch_preconditioner(preconditioner, A, rng=rng)
        if preconditioner is None:
            M = lambda x: x
        else:
            M = partial(matmul_high_precision, M)
        n = shape[0]
        x0 = rng(shape[:1], dtype)
        Q = np.zeros((n, n + 1), dtype=dtype)
        Q[:, 0] = x0 / jnp.linalg.norm(x0)
        Q = jnp.array(Q)
        H = jnp.eye(n, n + 1, dtype=dtype)

        @jax.tree_util.Partial
        def A_mv(x):
            return matmul_high_precision(A, x)

        for k in range(n):
            Q, H, _ = jax._src.scipy.sparse.linalg._kth_arnoldi_iteration(
                k, A_mv, M, Q, H)
        QA = matmul_high_precision(Q[:, :n].conj().T, A)
        QAQ = matmul_high_precision(QA, Q[:, :n])
        self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5, atol=1e-5)
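
For reference, the core API exercised by this example is jax.scipy.sparse.linalg.cg (and its bicgstab/gmres siblings). A minimal, self-contained call on a small symmetric positive-definite system, with illustrative values:

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# A small symmetric positive-definite system A x = b.
A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# The linear operator can be a matrix or, as here, a matvec function.
x, _ = cg(lambda v: A @ v, b, tol=1e-10)
print(x)                                   # approximately [0.0909, 0.6364]
print(jnp.allclose(A @ x, b, atol=1e-5))   # True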
Example No. 23
class EnergyTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_simple_spring(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)
        disp, _ = space.free()
        if spatial_dimension == 2:
            R = np.array([[0., 0.], [1., 1.]], dtype=dtype)
            dist = np.sqrt(2.)
        elif spatial_dimension == 3:
            R = np.array([[0., 0., 0.], [1., 1., 1.]], dtype=dtype)
            dist = np.sqrt(3.)
        bonds = np.array([[0, 1]], np.int32)
        for _ in range(STOCHASTIC_SAMPLES):
            key, l_key, a_key = random.split(key, 3)
            length = random.uniform(l_key, (), minval=0.1, maxval=3.0)
            alpha = random.uniform(a_key, (), minval=2., maxval=4.)
            E = energy.simple_spring_bond(disp,
                                          bonds,
                                          length=length,
                                          alpha=alpha)
            E_exact = dtype((dist - length)**alpha / alpha)
            self.assertAllClose(E(R), E_exact, True)

    # pylint: disable=g-complex-comprehension
    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_dim={}_alpha={}_dtype={}'.format(dim, alpha, dtype.__name__),
            'spatial_dimension':
            dim,
            'alpha':
            alpha,
            'dtype':
            dtype
        } for dim in SPATIAL_DIMENSION for alpha in SOFT_SPHERE_ALPHA
                            for dtype in POSITION_DTYPE))
    def test_soft_sphere(self, spatial_dimension, alpha, dtype):
        key = random.PRNGKey(0)
        alpha = f32(alpha)
        for _ in range(STOCHASTIC_SAMPLES):
            key, split_sigma, split_epsilon = random.split(key, 3)
            sigma = np.array(random.uniform(split_sigma, (1, ),
                                            minval=0.0,
                                            maxval=3.0)[0],
                             dtype=dtype)
            epsilon = np.array(random.uniform(split_epsilon, (1, ),
                                              minval=0.0,
                                              maxval=4.0)[0],
                               dtype=dtype)
            self.assertAllClose(
                energy.soft_sphere(dtype(0), sigma, epsilon, alpha),
                epsilon / alpha, True)
            self.assertAllClose(
                energy.soft_sphere(dtype(sigma), sigma, epsilon, alpha),
                np.array(0.0, dtype=dtype), True)

            if alpha == 3.0:
                grad_energy = grad(energy.soft_sphere)
                g = grad_energy(dtype(sigma), sigma, epsilon, alpha)
                self.assertAllClose(g, np.array(0, dtype=dtype), True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_lennard_jones(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split_sigma, split_epsilon = random.split(key, 3)
            sigma = dtype(
                random.uniform(split_sigma, (1, ), minval=0.5, maxval=3.0)[0])
            epsilon = dtype(
                random.uniform(split_epsilon, (1, ), minval=0.0,
                               maxval=4.0)[0])
            dr = dtype(sigma * 2**(1.0 / 6.0))
            self.assertAllClose(energy.lennard_jones(dr, sigma, epsilon),
                                np.array(-epsilon, dtype=dtype), True)
            g = grad(energy.lennard_jones)(dr, sigma, epsilon)
            self.assertAllClose(g, np.array(0, dtype=dtype), True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_isotropic_cutoff(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split_rs, split_rl, split_sigma, split_epsilon = random.split(
                key, 5)
            sigma = f32(
                random.uniform(split_sigma, (1, ), minval=0.5, maxval=3.0)[0])
            epsilon = f32(
                random.uniform(split_epsilon, (1, ), minval=0.0,
                               maxval=4.0)[0])
            r_small = random.uniform(split_rs, (10, ),
                                     minval=0.0,
                                     maxval=2.0 * sigma,
                                     dtype=dtype)
            r_large = random.uniform(split_rl, (10, ),
                                     minval=2.5 * sigma,
                                     maxval=3.0 * sigma,
                                     dtype=dtype)

            r_onset = f32(2.0 * sigma)
            r_cutoff = f32(2.5 * sigma)

            E = energy.multiplicative_isotropic_cutoff(energy.lennard_jones,
                                                       r_onset, r_cutoff)

            self.assertAllClose(E(r_small, sigma, epsilon),
                                energy.lennard_jones(r_small, sigma, epsilon),
                                True)
            self.assertAllClose(E(r_large, sigma, epsilon),
                                np.zeros_like(r_large, dtype=dtype), True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype,
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_soft_sphere_neighbor_list_energy(self, spatial_dimension, dtype):
        key = random.PRNGKey(1)

        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)
        exact_energy_fn = energy.soft_sphere_pair(displacement)

        R = box_size * random.uniform(key, (PARTICLE_COUNT, spatial_dimension),
                                      dtype=dtype)
        neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list(
            displacement, box_size, R)

        idx = neighbor_fn(R)

        self.assertAllClose(np.array(exact_energy_fn(R), dtype=dtype),
                            energy_fn(R, idx), True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype,
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_lennard_jones_cell_neighbor_list_energy(self, spatial_dimension,
                                                     dtype):
        key = random.PRNGKey(1)

        box_size = f32(15)
        displacement, _ = space.periodic(box_size)
        metric = space.metric(displacement)
        exact_energy_fn = energy.lennard_jones_pair(displacement)

        R = box_size * random.uniform(key, (PARTICLE_COUNT, spatial_dimension),
                                      dtype=dtype)
        neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
            displacement, box_size, R)

        idx = neighbor_fn(R)
        self.assertAllClose(np.array(exact_energy_fn(R), dtype=dtype),
                            energy_fn(R, idx), True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name': '_dim={}_dtype={}'.format(
                    dim, dtype.__name__),
                'spatial_dimension': dim,
                'dtype': dtype,
            } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE))
    def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype):
        key = random.PRNGKey(1)

        box_size = f32(15.0)
        displacement, _ = space.periodic(box_size)
        metric = space.metric(displacement)
        exact_force_fn = quantity.force(
            energy.lennard_jones_pair(displacement))

        R = box_size * random.uniform(key, (PARTICLE_COUNT, spatial_dimension),
                                      dtype=dtype)
        neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list(
            displacement, box_size, R)
        force_fn = quantity.force(energy_fn)

        idx = neighbor_fn(R)
        self.assertAllClose(np.array(exact_force_fn(R), dtype=dtype),
                            force_fn(R, idx), True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_num_reps={}_dtype={}'.format(num_repetitions, dtype.__name__),
            'num_repetitions':
            num_repetitions,
            'dtype':
            dtype,
        } for num_repetitions in UNIT_CELL_SIZE for dtype in POSITION_DTYPE))
    def test_eam(self, num_repetitions, dtype):
        latvec = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]],
                          dtype=dtype) * f32(4.05 / 2)
        atoms = np.array([[0, 0, 0]], dtype=dtype)
        atoms_repeated, latvec_repeated = lattice_repeater(
            atoms, latvec, num_repetitions)
        inv_latvec = np.array(onp.linalg.inv(onp.array(latvec_repeated)),
                              dtype=dtype)
        displacement, _ = space.periodic_general(latvec_repeated)
        charge_fn, embedding_fn, pairwise_fn = make_eam_test_splines()
        assert charge_fn(np.array(1.0, dtype)).dtype == dtype
        assert embedding_fn(np.array(1.0, dtype)).dtype == dtype
        assert pairwise_fn(np.array(1.0, dtype)).dtype == dtype
        eam_energy = energy.eam(displacement, charge_fn, embedding_fn,
                                pairwise_fn)
        tol = 1e-5 if dtype == np.float32 else 1e-6
        self.assertAllClose(
            eam_energy(np.dot(atoms_repeated, inv_latvec)) /
            np.array(num_repetitions**3, dtype), dtype(-3.363338), True, tol,
            tol)
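
The Lennard-Jones assertions above rest on the standard identity that the 12-6 potential E(r) = 4*eps*((sigma/r)**12 - (sigma/r)**6) reaches its minimum value -eps at r = sigma * 2**(1/6). A small pure-JAX sketch of that fact, independent of jax_md (names illustrative):

import jax
import jax.numpy as jnp

def lj(r, sigma, epsilon):
  # Classic 12-6 Lennard-Jones potential.
  s6 = (sigma / r) ** 6
  return 4.0 * epsilon * (s6 ** 2 - s6)

sigma, epsilon = 1.5, 2.0
r_min = sigma * 2.0 ** (1.0 / 6.0)
print(lj(r_min, sigma, epsilon))            # approximately -epsilon
print(jax.grad(lj)(r_min, sigma, epsilon))  # approximately 0.0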
Example No. 24
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
          jtu.format_shape_dtype_string(shapes, dtype),
          axis, keepdims, return_sign, use_b),
       # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
       "shapes": shapes, "dtype": dtype,
       "axis": axis, "keepdims": keepdims,
       "return_sign": return_sign, "use_b": use_b}
      for shape_group in compatible_shapes for dtype in float_dtypes + complex_dtypes + int_dtypes
      for use_b in [False, True]
      for shapes in itertools.product(*(
        (shape_group, shape_group) if use_b else (shape_group,)))
      for axis in range(-max(len(shape) for shape in shapes),
                         max(len(shape) for shape in shapes))
      for keepdims in [False, True]
      for return_sign in [False, True]))
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in .*")
  def testLogSumExp(self, shapes, dtype, axis,
                    keepdims, return_sign, use_b):
    if jtu.device_under_test() != "cpu":
      rng = jtu.rand_some_inf_and_nan(self.rng())
    else:
      rng = jtu.rand_default(self.rng())
    # TODO(mattjj): test autodiff
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      args_maker = lambda: [rng(shapes[0], dtype)]
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

  def testLogSumExpZeros(self):
    # Regression test for https://github.com/google/jax/issues/5370
    scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
    lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
    args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker)

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes,
         "test_autodiff": rec.test_autodiff,
         "nondiff_argnums": rec.nondiff_argnums,
         "scipy_op": getattr(osp_special, rec.name),
         "lax_op": getattr(lsp_special, rec.name)}
        for shapes in itertools.combinations_with_replacement(all_shapes, rec.nargs)
        for dtypes in (itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
          if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff, nondiff_argnums):
    if (jtu.device_under_test() == "cpu" and
        (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)):
      # TODO(b/173608403): re-enable test when LLVM bug is fixed.
      raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

    if test_autodiff:
      def partial_lax_op(*vals):
        list_args = list(vals)
        for i in nondiff_argnums:
          list_args.insert(i, args[i])
        return lax_op(*list_args)

      assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
      diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
      jtu.check_grads(partial_lax_op, diff_args, order=1,
                      atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                      rtol=.1, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_d={}".format(
          jtu.format_shape_dtype_string(shape, dtype), d),
       "shape": shape, "dtype": dtype, "d": d}
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, shape, dtype, d):
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    rng = jtu.rand_positive(self.rng())
    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(lax_fun, args_maker)

  def testIssue980(self):
    x = np.full((4,), -1e20, dtype=np.float32)
    self.assertAllClose(np.zeros((4,), dtype=np.float32),
                        lsp_special.expit(x))

  def testIssue3758(self):
    x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
    q = np.array([1., 40., 30.], dtype=np.float32)
    self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q))

  def testXlogyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
    self.assertAllClose(api.grad(partial_xlogy)(0.), 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_maxdegree={}_inputsize={}".format(l_max, num_z),
       "l_max": l_max,
       "num_z": num_z}
       for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
  def testLpmn(self, l_max, num_z):
    # Points at which the associated Legendre functions are evaluated.
    z = np.linspace(-0.2, 0.9, num_z)
    actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max, z=z)

    # The expected results are obtained from scipy.
    expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
    expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))

    for i in range(num_z):
      val, derivative = osp_special.lpmn(l_max, l_max, z[i])
      expected_p_vals[:, :, i] = val
      expected_p_derivatives[:, :, i] = derivative

    with self.subTest('Test values.'):
      self.assertAllClose(actual_p_vals, expected_p_vals, rtol=1e-6, atol=3.2e-6)

    with self.subTest('Test derivatives.'):
      self.assertAllClose(actual_p_derivatives, expected_p_derivatives,
                          rtol=1e-6, atol=8.4e-4)
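
For reference, the logsumexp cases above exercise the b (per-element scaling) and return_sign options of jax.scipy.special.logsumexp. A minimal stable computation, with illustrative values:

import numpy as np
from jax.scipy.special import logsumexp

a = np.array([-1000.0, -2.0])
b = np.array([1.0, 0.5])

# log(sum(b * exp(a))) computed without underflow.
print(logsumexp(a, b=b))                    # approximately log(0.5) - 2.0
out, sign = logsumexp(a, b=np.array([-1.0, 0.5]), return_sign=True)
print(out, sign)                            # log|sum| and the sign of the sum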
Example No. 25
class ScipyLinalgTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 5), (10, 5), (50, 50)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testLu(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng(shape, dtype)]
        x, = args_maker()
        p, l, u = jsp.linalg.lu(x)
        self.assertAllClose(x,
                            onp.matmul(p, onp.matmul(l, u)),
                            check_dtypes=True)
        self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True)

    def testLuOfSingularMatrix(self):
        x = np.array([[-1., 3. / 2], [2. / 3, -1.]], dtype=onp.float32)
        p, l, u = jsp.linalg.lu(x)
        self.assertAllClose(x,
                            onp.matmul(p, onp.matmul(l, u)),
                            check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(1, 1), (4, 5), (10, 5), (10, 10), (6, 7, 7)]
                            for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")  # TODO(phawkins): precision problems on TPU.
    def testLuGrad(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        a = rng(shape, dtype)
        lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
        jtu.check_grads(lu, (a, ), 2, atol=5e-2, rtol=1e-1)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for shape in [(4, 5), (6, 5)] for dtype in [np.float32]
                            for rng in [jtu.rand_default()]))
    def testLuBatching(self, shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args = [rng(shape, np.float32) for _ in range(10)]
        expected = list(osp.linalg.lu(x) for x in args)
        ps = onp.stack([out[0] for out in expected])
        ls = onp.stack([out[1] for out in expected])
        us = onp.stack([out[2] for out in expected])

        actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args))
        self.assertAllClose(ps, actual_ps, check_dtypes=True)
        self.assertAllClose(ls, actual_ls, check_dtypes=True)
        self.assertAllClose(us, actual_us, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)),
            "n":
            n,
            "dtype":
            dtype,
            "rng":
            rng
        } for n in [1, 4, 5, 200] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testLuFactor(self, n, dtype, rng):
        _skip_if_unsupported_type(dtype)
        args_maker = lambda: [rng((n, n), dtype)]

        x, = args_maker()
        lu, piv = jsp.linalg.lu_factor(x)
        l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype)
        u = onp.triu(lu)
        for i in range(n):
            x[[i, piv[i]], ] = x[[piv[i], i], ]
        self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3)
        self._CompileAndCheck(jsp.linalg.lu_factor,
                              args_maker,
                              check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}_trans={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), trans),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "trans":
            trans,
            "rng":
            rng
        } for lhs_shape, rhs_shape in [
            ((1, 1), (1, 1)),
            ((4, 4), (4, )),
            ((8, 8), (8, 4, 2)),
        ] for trans in [0, 1, 2] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testLuSolve(self, lhs_shape, rhs_shape, dtype, trans, rng):
        _skip_if_unsupported_type(dtype)
        osp_fun = lambda lu, piv, rhs: osp.linalg.lu_solve(
            (lu, piv), rhs, trans=trans)
        jsp_fun = lambda lu, piv, rhs: jsp.linalg.lu_solve(
            (lu, piv), rhs, trans=trans)

        def args_maker():
            a = rng(lhs_shape, dtype)
            lu, piv = osp.linalg.lu_factor(a)
            return [lu, piv, rng(rhs_shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}_sym_pos={}_lower={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), sym_pos,
                lower),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "sym_pos":
            sym_pos,
            "lower":
            lower,
            "rng":
            rng
        } for lhs_shape, rhs_shape in [
            ((1, 1), (1, 1)),
            ((4, 4), (4, )),
            ((8, 8), (8, 4)),
        ] for sym_pos, lower in [
            (False, False),
            (True, False),
            (True, True),
        ] for dtype in float_types + complex_types
                            for rng in [jtu.rand_default()]))
    def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng):
        _skip_if_unsupported_type(dtype)
        if (sym_pos and onp.issubdtype(dtype, onp.complexfloating)
                and jtu.device_under_test() == "tpu"):
            raise unittest.SkipTest(
                "Complex Cholesky decomposition not implemented on TPU")
        osp_fun = lambda lhs, rhs: osp.linalg.solve(
            lhs, rhs, sym_pos=sym_pos, lower=lower)
        jsp_fun = lambda lhs, rhs: jsp.linalg.solve(
            lhs, rhs, sym_pos=sym_pos, lower=lower)

        def args_maker():
            a = rng(lhs_shape, dtype)
            if sym_pos:
                a = onp.matmul(a, onp.conj(T(a)))
                a = onp.tril(a) if lower else onp.triu(a)
            return [a, rng(rhs_shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=True,
                                tol=1e-3)
        self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), lower,
                transpose_a, unit_diagonal),
            "lower":
            lower,
            "transpose_a":
            transpose_a,
            "unit_diagonal":
            unit_diagonal,
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "rng":
            rng
        } for lower in [False, True] for transpose_a in [False, True]
                            for unit_diagonal in [False, True]
                            for lhs_shape, rhs_shape in [
                                ((4, 4), (4, )),
                                ((4, 4), (4, 3)),
                                ((2, 8, 8), (2, 8, 10)),
                            ] for dtype in float_types
                            for rng in [jtu.rand_default()]))
    def testSolveTriangular(self, lower, transpose_a, unit_diagonal, lhs_shape,
                            rhs_shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        k = rng(lhs_shape, dtype)
        l = onp.linalg.cholesky(
            onp.matmul(k, T(k)) + lhs_shape[-1] * onp.eye(lhs_shape[-1]))
        l = l.astype(k.dtype)
        b = rng(rhs_shape, dtype)

        if unit_diagonal:
            a = onp.tril(l, -1) + onp.eye(lhs_shape[-1], dtype=dtype)
        else:
            a = l
        a = a if lower else T(a)

        inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype)
        if len(lhs_shape) == len(rhs_shape):
            onp_ans = onp.matmul(inv, b)
        else:
            onp_ans = onp.einsum("...ij,...j->...i", inv, b)

        # The standard scipy.linalg.solve_triangular doesn't support broadcasting.
        # But it seems like an inevitable extension so we support it.
        ans = jsp.linalg.solve_triangular(l if lower else T(l),
                                          b,
                                          trans=1 if transpose_a else 0,
                                          lower=lower,
                                          unit_diagonal=unit_diagonal)

        self.assertAllClose(onp_ans, ans, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".
                format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                       jtu.format_shape_dtype_string(rhs_shape, dtype), lower,
                       transpose_a, unit_diagonal),
                "lower":
                lower,
                "transpose_a":
                transpose_a,
                "unit_diagonal":
                unit_diagonal,
                "lhs_shape":
                lhs_shape,
                "rhs_shape":
                rhs_shape,
                "dtype":
                dtype,
                "rng":
                rng
            } for lower in [False, True] for unit_diagonal in [False, True]
            for dtype in float_types + complex_types for transpose_a in (
                [0, 1] if onp.issubdtype(dtype, np.floating) else [0, 1, 2])
            for lhs_shape, rhs_shape in [
                ((4, 4), (4, )),
                ((4, 4), (4, 3)),
                ((2, 8, 8), (2, 8, 10)),
            ] for rng in [jtu.rand_default()]))
    @jtu.skip_on_devices("tpu")  # TODO(phawkins): Test fails on TPU.
    def testSolveTriangularGrad(self, lower, transpose_a, unit_diagonal,
                                lhs_shape, rhs_shape, dtype, rng):
        _skip_if_unsupported_type(dtype)
        A = np.tril(
            rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
        A = A if lower else T(A)
        B = rng(rhs_shape, dtype)
        f = partial(jsp.linalg.solve_triangular,
                    lower=lower,
                    trans=transpose_a,
                    unit_diagonal=unit_diagonal)
        jtu.check_grads(f, (A, B), 2, rtol=2e-2, eps=1e-3)
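
As a quick usage reference for the factorization APIs exercised above, a minimal jax.scipy.linalg.lu_factor / lu_solve round trip, with illustrative values:

import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg

A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

lu, piv = jsp_linalg.lu_factor(A)           # packed LU factors + pivot indices
x = jsp_linalg.lu_solve((lu, piv), b)       # solves A x = b using the factors
print(x)                                    # approximately [2., 3.]
print(jnp.allclose(A @ x, b, atol=1e-5))    # True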
Example No. 26
class QuantityTest(jtu.JaxTestCase):
    def test_canonicalize_mass(self):
        assert quantity.canonicalize_mass(3.0) == 3.0
        assert quantity.canonicalize_mass(f32(3.0)) == f32(3.0)
        assert quantity.canonicalize_mass(f64(3.0)) == f64(3.0)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': '_dim={}'.format(dim),
            'spatial_dimension': dim,
        } for dim in SPATIAL_DIMENSION))
    def test_grad_kinetic_energy(self, spatial_dimension):
        key = random.PRNGKey(0)

        @jit
        def do_fn(theta):
            key = random.PRNGKey(0)
            V = random.normal(key, (PARTICLE_COUNT, spatial_dimension),
                              dtype=f32)

            return quantity.kinetic_energy(theta * V)

        grad(do_fn)(2.0)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': '_dtype={}'.format(dtype.__name__),
            'dtype': dtype,
        } for dtype in DTYPES))
    def test_cosine_angles(self, dtype):
        displacement, _ = space.free()
        displacement = space.map_product(displacement)
        R = np.array([[0, 0], [0, 1], [1, 1]], dtype=dtype)
        dR = displacement(R, R)
        cangles = quantity.cosine_angles(dR)
        c45 = 1 / np.sqrt(2)
        true_cangles = np.array([[[0, 0, 0], [0, 1, c45], [0, c45, 1]],
                                 [[1, 0, 0], [0, 0, 0], [0, 0, 1]],
                                 [[1, c45, 0], [c45, 1, 0], [0, 0, 0]]],
                                dtype=dtype)
        self.assertAllClose(cangles, true_cangles)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': '_dtype={}'.format(dtype.__name__),
            'dtype': dtype,
        } for dtype in DTYPES))
    def test_cosine_angles_neighbors(self, dtype):
        displacement, _ = space.free()
        displacement = vmap(vmap(displacement, (None, 0)), 0)

        R = np.array([[0, 0], [0, 1], [1, 1]], dtype=dtype)
        R_neigh = np.array(
            [[[0, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]],
            dtype=dtype)

        dR = displacement(R, R_neigh)

        cangles = quantity.cosine_angles(dR)
        c45 = 1 / np.sqrt(2)
        true_cangles = np.array(
            [[[1, c45], [c45, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
            dtype=dtype)
        self.assertAllClose(cangles, true_cangles)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': '_dtype={}'.format(dtype.__name__),
            'dtype': dtype,
        } for dtype in DTYPES))
    def test_pair_correlation(self, dtype):
        displacement = lambda Ra, Rb, **kwargs: Ra - Rb
        R = np.array([[1, 0], [0, 0], [0, 1]], dtype=dtype)
        rs = np.linspace(0, 2, 60, dtype=dtype)
        g = quantity.pair_correlation(displacement, rs, f32(0.1))
        gs = g(R)
        gs = np.mean(gs, axis=0)
        assert np.argmax(gs) == np.argmin((rs - 1.)**2)
        assert gs.dtype == dtype
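The cosine-angle tests above check quantity.cosine_angles from jax_md against hand-computed values. As a sanity check of where the c45 = 1/sqrt(2) entries come from, the same cosine can be computed directly from the pairwise displacement vectors with plain NumPy (a back-of-the-envelope sketch, not the jax_md API):

import numpy as onp

R = onp.array([[0., 0.], [0., 1.], [1., 1.]])
dR = R[None, :, :] - R[:, None, :]      # dR[i, j] is the displacement from particle i to particle j
d01, d02 = dR[0, 1], dR[0, 2]           # displacements from particle 0 to particles 1 and 2
cos_angle = d01 @ d02 / (onp.linalg.norm(d01) * onp.linalg.norm(d02))
print(cos_angle)                        # 0.7071... = 1/sqrt(2), the c45 entry asserted above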
Example No. 27
class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):

    # This test runs for all primitive harnesses. For each primitive "xxx" the
    # test will be called "test_prim_xxx_..." and the custom parameters for
    # the test are defined in the class method "jax2tf_limitations.Jax2TfLimitation.xxx".
    # See more details in the comment at top of file and in Jax2TfLimitation class.
    # If you want to run this test for only one harness, add parameter
    # `one_containing="foo"` to parameterized below.
    @primitive_harness.parameterized(primitive_harness.all_harnesses,
                                     include_jax_unimpl=False)
    @jtu.ignore_warning(category=UserWarning,
                        message="Using reduced precision for gradient.*")
    def test_prim(self, harness: primitive_harness.Harness):
        limitations = Jax2TfLimitation.limitations_for_harness(harness)
        device = jtu.device_under_test()
        limitations = tuple(
            filter(lambda l: l.filter(device=device, dtype=harness.dtype),
                   limitations))
        func_jax = harness.dyn_fun
        args = harness.dyn_args_maker(self.rng())
        enable_xla = harness.params.get("enable_xla", True)
        self.ConvertAndCompare(func_jax,
                               *args,
                               limitations=limitations,
                               enable_xla=enable_xla)

    def test_primitive_coverage(self):
        """Fail if there are JAX primitives that are not implemented."""
        # Harvest primitives from XLA translation tables
        all_primitives = (set(xla.translations)
                          | set(xla.backend_specific_translations["cpu"])
                          | set(xla.backend_specific_translations["gpu"])
                          | set(xla.backend_specific_translations["tpu"])
                          | set(xla.initial_style_translations)
                          | set(xla.parallel_translations))

        tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl) | set(
            jax.experimental.jax2tf.jax2tf.tf_impl_with_avals)
        tf_not_yet_impl = set(jax.experimental.jax2tf.jax2tf.tf_not_yet_impl)

        all_primitives = tuple(sorted(all_primitives, key=str))
        for p in all_primitives:
            if p.name == "axis_index":
                continue
            if p.name in tf_not_yet_impl:
                self.assertNotIn(
                    p, tf_impl
                )  # Should not be in both tf_impl and tf_not_yet_impl
            else:
                self.assertIn(p, tf_impl)

    def test_generate_limitations_doc(self):
        """Generates primitives_with_limited_support.md.

        See the doc for instructions.
        """

        harnesses = [
            h for h in primitive_harness.all_harnesses
            if h.filter(h, include_jax_unimpl=True)
        ]
        print(f"Found {len(harnesses)} test harnesses that work in JAX")

        def unique_hash(h: primitive_harness.Harness, l: Jax2TfLimitation):
            return (h.group_name, l.description, l.devices,
                    tuple([np.dtype(d).name for d in l.dtypes]), l.modes)

        unique_limitations: Dict[Any, Tuple[primitive_harness.Harness,
                                            Jax2TfLimitation]] = {}
        for h in harnesses:
            for l in h.jax_unimplemented:
                if l.enabled:
                    # Fake a Jax2TfLimitation from the Limitation
                    tfl = Jax2TfLimitation(
                        description="Not implemented in JAX: " + l.description,
                        devices=l.devices,
                        dtypes=l.dtypes,
                        expect_tf_error=False,
                        skip_tf_run=True)
                    unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl)
        for h in harnesses:
            for l in Jax2TfLimitation.limitations_for_harness(h):
                unique_limitations[hash(unique_hash(h, l))] = (h, l)

        print(f"Found {len(unique_limitations)} unique limitations")
        tf_error_table = [
            """
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- |"""
        ]
        tf_numerical_discrepancies_table = list(tf_error_table)  # a copy
        for h, l in sorted(unique_limitations.values(),
                           key=lambda pair: unique_hash(*pair)):
            devices = ", ".join(sorted(l.devices))
            modes = ", ".join(sorted(l.modes))
            description = l.description
            if l.skip_comparison:
                description = "Numeric comparision disabled: " + description
            if l.expect_tf_error:
                description = "TF error: " + description
            if l.skip_tf_run:
                description = "TF test skipped: " + description

            if l.skip_tf_run or l.expect_tf_error:
                to_table = tf_error_table
            elif l.skip_comparison or l.custom_assert:
                to_table = tf_numerical_discrepancies_table
            else:
                continue

            to_table.append(
                f"| {h.group_name} | {description} | "
                f"{primitive_harness.dtypes_to_str(l.dtypes, empty_means_all=True)} | {devices} | {modes} |"
            )

        if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"):
            raise unittest.SkipTest(
                "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation"
            )
        # The CPU supports the most types and harnesses, so documentation
        # generation must run on CPU.
        self.assertEqual("cpu", jtu.device_under_test())
        self.assertTrue(
            config.x64_enabled,
            "Documentation generation must be run with JAX_ENABLE_X64=1")

        with open(
                os.path.join(
                    os.path.dirname(__file__),
                    "../g3doc/primitives_with_limited_support.md.template")
        ) as f:
            template = f.read()
        output_file = os.path.join(
            os.path.dirname(__file__),
            "../g3doc/primitives_with_limited_support.md")

        with open(output_file, "w") as f:
            f.write(template.replace("{{generation_date}}", str(datetime.date.today())) \
                    .replace("{{tf_error_table}}", "\n".join(tf_error_table)) \
                    .replace("{{tf_numerical_discrepancies_table}}", "\n".join(tf_numerical_discrepancies_table)) \
                    )

    # The rest of the tests check special cases

    @parameterized.named_parameters(
        dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) for f_jax in [
            jnp.add, jnp.subtract, jnp.multiply, jnp.divide, jnp.less,
            jnp.less_equal, jnp.equal, jnp.greater, jnp.greater_equal,
            jnp.not_equal, jnp.maximum, jnp.minimum
        ])
    def test_type_promotion(self, f_jax=jnp.add):
        # We only test a few types here, as tensorflow does not support many
        # types like uint* or bool in binary ops.
        types = [dtypes.bfloat16, np.int32, np.int64, np.float32]
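        # For example, int32 + float32 promotes to float32, and bfloat16 + float32
        # also promotes to float32 under JAX's type-promotion rules.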
        for x_dtype in types:
            for y_dtype in types:
                x = np.array([1, 2], dtype=x_dtype)
                y = np.array([3, 4], dtype=y_dtype)
                self.ConvertAndCompare(f_jax, x, y)

    def test_integer_div(self):
        x = jnp.array([-4, -3, -1, 0, 1, 3, 6])
        y = np.int32(3)
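        # floor_divide rounds toward negative infinity, so x // 3 is [-2, -1, -1, 0, 0, 1, 2].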
        self.ConvertAndCompare(jnp.floor_divide, x, y)
        expected = jnp.floor_divide(x, y)
        # Try it with TF 1 as well (#5831)
        with tf.compat.v1.Session() as sess:
            tf1_res = sess.run(jax2tf.convert(jnp.floor_divide)(x, y))
            self.assertAllClose(expected, tf1_res)

    def test_disable_xla(self):
        def fun(x):
            return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)])

        with self.assertRaisesRegex(
                NotImplementedError,
                "Call to pad cannot be converted with enable_xla=False."):
            self.ConvertAndCompare(fun,
                                   np.ones((2, 3), dtype=np.float32),
                                   enable_xla=False)

    def test_boolean_gather(self):
        values = np.array([[True, True], [False, True], [False, False]],
                          dtype=np.bool_)
        indices = np.array([0, 1], dtype=np.int32)
        for axis in [0, 1]:
            f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
            self.ConvertAndCompare(f_jax, values, indices)

    def test_gather_rank_change(self):
        params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]])
        indices = jnp.array([[1, 1, 2], [0, 1, 0]])
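        # Indexing the (3, 3) params with (2, 3) integer indices gathers rows,
        # producing an output of shape (2, 3, 3).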
        f_jax = jax.jit(lambda i: params[i])
        self.ConvertAndCompare(f_jax, indices)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in REDUCE))
    def test_reduce_ops_with_numerical_input(self, f_jax):
        values = np.array([1, 2, 3], dtype=np.float32)
        self.ConvertAndCompare(f_jax, values)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{op.__name__}", op=op) for op in (
                jax.ops.index_add,
                jax.ops.index_max,
                jax.ops.index_min,
                jax.ops.index_mul,
                jax.ops.index_update,
            )))
    def test_scatter_static(self, op):
        values = np.ones((5, 6), dtype=np.float32)
        update = np.float32(6.)
        f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u))
        self.ConvertAndCompare(f_jax, values, update)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in REDUCE))
    def test_reduce_ops_with_boolean_input(self, f_jax):
        values = np.array([True, False, True], dtype=np.bool_)
        self.ConvertAndCompare(f_jax, values)
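REDUCE is defined elsewhere in this test module; assuming it holds standard jnp reductions such as jnp.sum and jnp.prod, the two reduction tests above boil down to checks like this sketch (boolean inputs are promoted to integers before reduction, matching NumPy):

import numpy as np
import jax.numpy as jnp

values = np.array([True, False, True])
print(jnp.sum(values))   # 2: booleans are promoted to integers before summing
print(np.sum(values))    # 2, matching NumPy semantics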
Example No. 28
class StaxTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape), "shape": shape}
      for shape in [(2, 3), (5,)]))
  def testRandnInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.randn()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape), "shape": shape}
      for shape in [(2, 3), (2, 3, 4)]))
  def testGlorotInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.glorot()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
       .format(channels, filter_shape, padding, strides, input_shape),
       "channels": channels, "filter_shape": filter_shape, "padding": padding,
       "strides": strides, "input_shape": input_shape}
      for channels in [2, 3]
      for filter_shape in [(1, 1), (2, 3)]
      for padding in ["SAME", "VALID"]
      for strides in [None, (2, 1)]
      for input_shape in [(2, 10, 11, 1)]))
  def testConvShape(self, channels, filter_shape, padding, strides,
                    input_shape):
    init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides,
                                    padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
       .format(channels, filter_shape, padding, strides, input_shape),
       "channels": channels, "filter_shape": filter_shape, "padding": padding,
       "strides": strides, "input_shape": input_shape}
      for channels in [2, 3]
      for filter_shape in [(1, 1), (2, 3), (3, 3)]
      for padding in ["SAME", "VALID"]
      for strides in [None, (2, 1), (2, 2)]
      for input_shape in [(2, 10, 11, 1)]))
  def testConvTransposeShape(self, channels, filter_shape, padding, strides,
                               input_shape):
    init_fun, apply_fun = stax.ConvTranspose(channels, filter_shape,  # 2D
                                               strides=strides, padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
       .format(channels, filter_shape, padding, strides, input_shape),
       "channels": channels, "filter_shape": filter_shape, "padding": padding,
       "strides": strides, "input_shape": input_shape}
      for channels in [2, 3]
      for filter_shape in [(1,), (2,), (3,)]
      for padding in ["SAME", "VALID"]
      for strides in [None, (1,), (2,)]
      for input_shape in [(2, 10, 1)]))
  def testConv1DTransposeShape(self, channels, filter_shape, padding, strides,
                               input_shape):
    init_fun, apply_fun = stax.Conv1DTranspose(channels, filter_shape,
                                               strides=strides, padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_out_dim={}_input_shape={}"
                        .format(out_dim, input_shape),
       "out_dim": out_dim, "input_shape": input_shape}
      for out_dim in [3, 4]
      for input_shape in [(2, 3), (3, 4)]))
  def testDenseShape(self, out_dim, input_shape):
    init_fun, apply_fun = stax.Dense(out_dim)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}_nonlinear={}"
                        .format(input_shape, nonlinear),
       "input_shape": input_shape, "nonlinear": nonlinear}
      for input_shape in [(2, 3), (2, 3, 4)]
      for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"]))
  def testNonlinearShape(self, input_shape, nonlinear):
    init_fun, apply_fun = getattr(stax, nonlinear)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_window_shape={}_padding={}_strides={}_input_shape={}"
                        "_maxpool={}_spec={}"
                        .format(window_shape, padding, strides, input_shape,
                                max_pool, spec),
       "window_shape": window_shape, "padding": padding, "strides": strides,
       "input_shape": input_shape, "max_pool": max_pool, "spec": spec}
      for window_shape in [(1, 1), (2, 3)]
      for padding in ["VALID"]
      for strides in [None, (2, 1)]
      for input_shape in [(2, 5, 6, 4)]
      for max_pool in [False, True]
      for spec in ["NHWC", "NCHW", "WHNC", "WHCN"]))
  def testPoolingShape(self, window_shape, padding, strides, input_shape,
                       max_pool, spec):
    layer = stax.MaxPool if max_pool else stax.AvgPool
    init_fun, apply_fun = layer(window_shape, padding=padding, strides=strides,
                                spec=spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(input_shape),
       "input_shape": input_shape}
      for input_shape in [(2, 3), (2, 3, 4)]))
  def testFlattenShape(self, input_shape):
    init_fun, apply_fun = stax.Flatten
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}_spec={}".format(input_shape, i),
       "input_shape": input_shape, "spec": spec}
      for input_shape in [(2, 5, 6, 1)]
      for i, spec in enumerate([
          [stax.Conv(3, (2, 2))],
          [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]])))
  def testSerialComposeLayersShape(self, input_shape, spec):
    init_fun, apply_fun = stax.serial(*spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}".format(input_shape),
       "input_shape": input_shape}
      for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testDropoutShape(self, input_shape):
    init_fun, apply_fun = stax.Dropout(0.9)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}".format(input_shape),
       "input_shape": input_shape}
      for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testFanInSum(self, input_shape):
    init_fun, apply_fun = stax.FanInSum
    _CheckShapeAgreement(self, init_fun, apply_fun, [input_shape, input_shape])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshapes={}_axis={}".format(input_shapes, axis),
       "input_shapes": input_shapes, "axis": axis}
      for input_shapes, axis in [
          ([(2, 3), (2, 1)], 1),
          ([(2, 3), (2, 1)], -1),
          ([(1, 2, 4), (1, 1, 4)], 1),
      ]))
  def testFanInConcat(self, input_shapes, axis):
    init_fun, apply_fun = stax.FanInConcat(axis)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

  def testIssue182(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.Softmax
    input_shape = (10, 3)
    inputs = onp.arange(30.).astype("float32").reshape(input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    assert out_shape == out.shape
    assert onp.allclose(onp.sum(onp.asarray(out), -1), 1.)

  def testBatchNormNoScaleOrCenter(self):
    key = random.PRNGKey(0)
    axes = (0, 1, 2)
    init_fun, apply_fun = stax.BatchNorm(axis=axes, center=False, scale=False)
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)
    means = onp.mean(out, axis=(0, 1, 2))
    std_devs = onp.std(out, axis=(0, 1, 2))
    assert onp.allclose(means, onp.zeros_like(means), atol=1e-4)
    assert onp.allclose(std_devs, onp.ones_like(std_devs), atol=1e-4)

  def testBatchNormShapeNHWC(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    self.assertEqual(beta.shape, (7,))
    self.assertEqual(gamma.shape, (7,))
    self.assertEqual(out_shape, out.shape)

  def testBatchNormShapeNCHW(self):
    key = random.PRNGKey(0)
    # Regression test for https://github.com/google/jax/issues/461
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    self.assertEqual(beta.shape, (5,))
    self.assertEqual(gamma.shape, (5,))
    self.assertEqual(out_shape, out.shape)

  def testAttentionShape(self):
    batch_size = 7
    nhead = 2
    query_dim, value_dim, out_dim = (nhead * 3, nhead * 5, 29)
    query_src_dim, key_src_dim, value_src_dim = (11, 13, 17)
    kv_len, q_len = (19, 23)

    input_shape = [(batch_size, q_len, query_src_dim),
                   (batch_size, kv_len, key_src_dim),
                   (batch_size, kv_len, value_src_dim),
                   (batch_size, kv_len)]

    init_fun, apply_fun = stax.GeneralAttention(
        out_dim, nhead=nhead, query_dim=query_dim, value_dim=value_dim)

    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  def testAttentionBehavior(self):
    key = random.PRNGKey(0)

    batch_size = 7
    nhead = 2
    query_dim, value_dim, out_dim = (nhead * 3, nhead * 3, nhead * 3)
    query_src_dim, key_src_dim, value_src_dim = (out_dim, out_dim, out_dim)
    kv_len, q_len = (19, 19)  # Use same lengths for testing causal mask

    input_shape = [(batch_size, q_len, query_src_dim),
                   (batch_size, kv_len, key_src_dim),
                   (batch_size, kv_len, value_src_dim),
                   (batch_size, kv_len)]

    init_fun, apply_fun = stax.GeneralAttention(
        out_dim, nhead=nhead, query_dim=query_dim, value_dim=value_dim,
        W_init=identity_initializer, b_init=nn.initializers.zeros)

    unused_output_shape, params = init_fun(key, input_shape)
    query, key, value = [onp.zeros(shape) for shape in input_shape[:3]]
    mask = onp.ones(input_shape[3])
    output = apply_fun(params, (query, key, value, mask))

    # Attention to zero values with zero keys must be zero
    self.assertAllClose(output, onp.zeros(output.shape), check_dtypes=True)

    # Make score(K[0, 4], Q[0, 3]) and score(K[0, 4], Q[0, 8]) huge
    #  for all heads
    for n in range(nhead):
      head_offset = out_dim // nhead * n
      key[0, 4, 0 + head_offset] = 100.0
      query[0, 3, 0 + head_offset] = 100.0
      query[0, 8, 0 + head_offset] = 100.0

    rand_vec = onp.random.RandomState(0).randn(out_dim)
    value[0, 4] = rand_vec

    # With identity initializers and the above key and query,
    #  output[0, 3] must be close enough to value[0, 4].
    output = apply_fun(params, (query, key, value, mask))
    expected = onp.zeros(output.shape)
    expected[0, :, :] = rand_vec / kv_len
    expected[0, 3, :] = rand_vec
    expected[0, 8, :] = rand_vec
    self.assertAllClose(output, expected, check_dtypes=True)

    def causal_mask(shape):
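      # cumsum of the identity along axis 0 gives a lower-triangular matrix of
      # ones, so each position can attend only to itself and earlier positions.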
      mask = onp.cumsum(onp.identity(shape[2]), axis=0)
      return mask

    init_fun, apply_fun = stax.GeneralAttention(
        out_dim, nhead=nhead, query_dim=query_dim, value_dim=value_dim,
        att_prob_mask_fun=causal_mask,
        W_init=identity_initializer, b_init=nn.initializers.zeros)

    output = apply_fun(params, (query, key, value, mask))

    # Check that it cannot attend to the future value.
    self.assertAllClose(output[0, 3, :], onp.zeros(out_dim), check_dtypes=True)
    # But for the 8th query, the value stored at index 4 must be retrieved.
    self.assertAllClose(output[0, 8, :], rand_vec, check_dtypes=True)
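All of the StaxTest shape tests rely on the same init_fun/apply_fun contract: init_fun(key, input_shape) returns (output_shape, params), and apply_fun(params, inputs) returns the activations. A minimal sketch of that contract, assuming stax is imported from jax.example_libraries (jax.experimental.stax in older releases):

import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax  # jax.experimental.stax in older JAX versions

init_fun, apply_fun = stax.serial(stax.Dense(8), stax.Relu, stax.Dense(2))
key = random.PRNGKey(0)
out_shape, params = init_fun(key, (4, 3))      # batch of 4 examples with 3 features each
out = apply_fun(params, jnp.ones((4, 3)))
assert out_shape == out.shape == (4, 2)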
Example No. 29
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_axis={}_keepdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
            "rng":
            jtu.rand_default(),
            "shape":
            shape,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims
        } for shape in all_shapes for dtype in float_dtypes
                            for axis in range(-len(shape), len(shape))
                            for keepdims in [False, True]))
    @jtu.skip_on_flag("jax_xla_backend", "xrt")
    def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
        # TODO(mattjj): test autodiff
        def scipy_fun(array_to_reduce):
            return osp_special.logsumexp(array_to_reduce,
                                         axis,
                                         keepdims=keepdims)

        def lax_fun(array_to_reduce):
            return lsp_special.logsumexp(array_to_reduce,
                                         axis,
                                         keepdims=keepdims)

        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                "rng":
                rec.rng,
                "shapes":
                shapes,
                "dtypes":
                dtypes,
                "test_autodiff":
                rec.test_autodiff,
                "scipy_op":
                getattr(osp_special, rec.name),
                "lax_op":
                getattr(lsp_special, rec.name)
            } for rec in JAX_SPECIAL_FUNCTION_RECORDS
            for shapes in CombosWithReplacement(all_shapes, rec.nargs)
            for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)))
    def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                            test_autodiff):
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

        if test_autodiff:
            jtu.check_grads(lax_op,
                            args,
                            order=1,
                            atol=1e-3,
                            rtol=3e-3,
                            eps=1e-3)

    def testIssue980(self):
        x = onp.full((4, ), -1e20, dtype=onp.float32)
        self.assertAllClose(onp.zeros((4, ), dtype=onp.float32),
                            lsp_special.expit(x),
                            check_dtypes=True)
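testLogSumExp compares against scipy's logsumexp, which relies on the numerically stable formulation logsumexp(x) = max(x) + log(sum(exp(x - max(x)))). A quick sketch of why the stable form matters in float32:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1000.0, 1000.0])
print(logsumexp(x))                    # ~1000.6931 (= 1000 + log 2), computed stably
print(jnp.log(jnp.sum(jnp.exp(x))))    # inf: the naive formula overflows in float32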
Example No. 30
class TestLBFGS(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_func={}_maxiter={}".format(func_and_init[0].__name__, maxiter),
     "maxiter": maxiter, "func_and_init": func_and_init}
    for maxiter in [None]
    for func_and_init in [(rosenbrock, np.zeros(2)),
                          (himmelblau, np.zeros(2)),
                          (matyas, np.ones(2) * 6.),
                          (eggholder, np.ones(2) * 100.)]))
  def test_minimize(self, maxiter, func_and_init):

    func, x0 = func_and_init

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          func(jnp),
          x0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(maxiter=maxiter, gtol=1e-7),
      )
      return result.x

    jax_res = min_op(x0)

    # Note that without bounds, L-BFGS-B is just L-BFGS
    with jtu.ignore_warning(category=DeprecationWarning,
                            message=".*tostring.*is deprecated.*"):
      scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x

    if func.__name__ == 'matyas':
      # scipy performs badly on Matyas, so compare against the true minimum instead
      self.assertAllClose(jax_res, jnp.zeros_like(jax_res), atol=1e-7)
      return

    if func.__name__ == 'eggholder':
      # L-BFGS performs poorly on the eggholder function.
      # Neither scipy nor jax finds the true minimum, so we can only loosely
      # (with a high atol) compare their non-optimal results.
      self.assertAllClose(jax_res, scipy_res, atol=1e-3)
      return

    self.assertAllClose(jax_res, scipy_res, atol=2e-5, check_dtypes=False)

  def test_minimize_complex_sphere(self):
    z0 = jnp.array([1., 2. - 3.j, 4., -5.j])

    def f(z):
      return jnp.real(jnp.dot(jnp.conj(z - z0), z - z0))

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          f,
          x0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(gtol=1e-6),
      )
      return result.x

    jax_res = min_op(jnp.zeros_like(z0))

    self.assertAllClose(jax_res, z0)

  def test_complex_rosenbrock(self):
    complex_dim = 5

    f_re = rosenbrock(jnp)
    init_re = jnp.zeros((2 * complex_dim,))
    expect_re = jnp.ones((2 * complex_dim,))

    def f(z):
      x_re = jnp.concatenate([jnp.real(z), jnp.imag(z)])
      return f_re(x_re)

    init = init_re[:complex_dim] + 1.j * init_re[complex_dim:]
    expect = expect_re[:complex_dim] + 1.j * expect_re[complex_dim:]

    @jit
    def min_op(z0):
      result = jax.scipy.optimize.minimize(
          f,
          z0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(gtol=1e-6),
      )
      return result.x

    jax_res = min_op(init)
    self.assertAllClose(jax_res, expect, atol=2e-5)
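The complex Rosenbrock test above handles complex parameters by mapping them to a stacked real vector inside the objective. A minimal sketch of the same reparameterization, here using the stable 'BFGS' method instead of the experimental L-BFGS string (the helper names below are illustrative, not part of the test suite):

import jax
import jax.numpy as jnp

def to_real(z):
  return jnp.concatenate([jnp.real(z), jnp.imag(z)])

def to_complex(x):
  n = x.shape[0] // 2
  return x[:n] + 1j * x[n:]

z_target = jnp.array([1.0 + 2.0j, -3.0 + 0.5j])

def loss(x):
  z = to_complex(x)
  return jnp.real(jnp.vdot(z - z_target, z - z_target))

x0 = to_real(jnp.zeros_like(z_target))           # optimize over the stacked real vector
result = jax.scipy.optimize.minimize(loss, x0, method='BFGS')
print(to_complex(result.x))                      # approximately z_target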