Example #1
    def testMultiBackend(self, backend):
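        # jit with an explicit `backend` argument: check both the numerical
        # result and the platform the output buffer ends up on.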
        if backend not in ('cpu', jtu.device_under_test(), None):
            raise SkipTest("Backend is not CPU or the device under test")

        @partial(jax.jit, backend=backend)
        def fun(x, y):
            return jnp.matmul(x, y)

        x = npr.uniform(size=(10, 10))
        y = npr.uniform(size=(10, 10))
        z_host = np.matmul(x, y)
        z = fun(x, y)
        self.assertAllClose(z, z_host, rtol=1e-2)
        correct_platform = backend if backend else jtu.device_under_test()
        self.assertEqual(z.device_buffer.platform(), correct_platform)
Example #2
  def test_custom_linear_solve_cholesky(self):

    def positive_definite_solve(a, b):
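      # Factor `a` once with Cholesky; lax.custom_linear_solve uses `matvec`
      # to define implicit differentiation of the solve.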
      factors = jsp.linalg.cho_factor(a)
      def solve(matvec, x):
        return jsp.linalg.cho_solve(factors, x)
      matvec = partial(high_precision_dot, a)
      return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)

    tol = {np.float32: 1E-3 if jtu.device_under_test() == "tpu" else 1E-5,
           np.float64: 1E-12}
    expected = jnp.linalg.solve(np.asarray(posify(a)), b)
    actual = positive_definite_solve(posify(a), b)
    self.assertAllClose(expected, actual, rtol=tol, atol=tol)

    actual = jax.jit(positive_definite_solve)(posify(a), b)
    self.assertAllClose(expected, actual, rtol=tol, atol=tol)

    # numerical gradients are only well defined if ``a`` is guaranteed to be
    # positive definite.
    jtu.check_grads(
        lambda x, y: positive_definite_solve(posify(x), y),
        (a, b), order=2, rtol=0.3)
Example #3
  def test_ragged_batched_rnn(self):
    n = 3

    @partial(mask, in_shapes=('(_, _)', '(t, _)', '_'), out_shape='')
    def rnn(W, xs, target):
      def step(h, x):
        new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
        return new_h, ()
      predicted, _ = lax.scan(step, jnp.zeros(n), xs)
      return jnp.sum((predicted - target)**2)

    rng = self.rng()
    W = rng.randn(n, n).astype(jnp.float_)
    seqs = rng.randn(3, 10, n).astype(jnp.float_)
    ts = jnp.array([2, 5, 4])
    ys = rng.randn(3, n)
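    # vmap over the padded sequences; dict(t=ts) supplies each sequence's true
    # (ragged) length to the masked function.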

    ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))((W, seqs, ys), dict(t=ts)).sum())(W)

    def rnn_reference(W, seqs, targets):
      total_loss = jnp.array(0.0)
      for xs, target in zip(seqs, targets):
        h = jnp.zeros(n)
        for x in xs:
          h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
        predicted = h
        total_loss = total_loss + jnp.sum((predicted - target)**2)
      return total_loss

    seqs_ = [xs[:t] for xs, t in zip(seqs, ts)]
    expected = grad(lambda W: rnn_reference(W, seqs_, ys).sum())(W)

    self.assertAllClose(
      ans, expected, check_dtypes=False,
      rtol=0.1 if jtu.device_under_test() == "tpu" else 1e-5)
Example #4
  def testLogSumExp(self, shapes, dtype, axis,
                    keepdims, return_sign, use_b):
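    # Off-CPU backends draw inputs containing infs and NaNs; CPU uses the
    # default sampler.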
    if jtu.device_under_test() != "cpu":
      rng = jtu.rand_some_inf_and_nan(self.rng())
    else:
      rng = jtu.rand_default(self.rng())
    # TODO(mattjj): test autodiff
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      args_maker = lambda: [rng(shapes[0], dtype)]
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
Example #5
    def test_function_dynamic_shape(self):
        if jtu.device_under_test() == "tpu":
            raise unittest.SkipTest("TODO: why does this fail on TPU?")
        # Call a function for which shape inference does not give an output
        # shape.
        x = np.array([-1, 0, 1], dtype=np.int32)

        def fun_tf(x):  # x:i32[3]
            # The shape depends on the value of x
            res = tf.where(x >= 0)
            return res

        # Call in eager mode. Should work!
        res1 = jax2tf.call_tf(fun_tf)(x)
        expected = np.array([[1], [2]])
        self.assertAllClose(expected, res1, check_dtypes=False)

        # Now under jit, should fail because the function is not compileable
        with self.assertRaisesRegex(
                ValueError,
                "Error compiling TensorFlow function. call_tf can used in a staged context"
        ):
            fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
            fun_jax(x)

        # TODO(necula): this should work in op-by-op mode, but it fails because
        # jax2tf.convert does abstract evaluation.
        with self.assertRaisesRegex(
                ValueError,
                "Error compiling TensorFlow function. call_tf can used in a staged context"
        ):
            fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf))
            fun_tf_rt(x)
Example #6
    def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes,
                            dtypes, test_autodiff, nondiff_argnums):
        if (jtu.device_under_test() == "cpu"
                and (lax_op is lsp_special.gammainc
                     or lax_op is lsp_special.gammaincc)):
            # TODO(b/173608403): re-enable test when LLVM bug is fixed.
            raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
        rng = rng_factory(self.rng())
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

        if test_autodiff:

            def partial_lax_op(*vals):
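                # Re-insert the non-differentiable arguments at their original
                # positions so check_grads only differentiates the rest.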
                list_args = list(vals)
                for i in nondiff_argnums:
                    list_args.insert(i, args[i])
                return lax_op(*list_args)

            assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
            diff_args = [
                x for i, x in enumerate(args) if i not in nondiff_argnums
            ]
            jtu.check_grads(partial_lax_op,
                            diff_args,
                            order=1,
                            atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                            rtol=.1,
                            eps=1e-3)
Example #7
    def test_jax_implemented(self, harness: primitive_harness.Harness):
        """Runs all harnesses just with JAX to verify the jax_unimplemented field.
    """
        jax_unimpl = [
            l for l in harness.jax_unimplemented
            if l.filter(device=jtu.device_under_test(), dtype=harness.dtype)
        ]
        if any(lim.skip_run for lim in jax_unimpl):
            logging.info(
                "Skipping run with expected JAX limitations: %s in harness %s",
                [u.description for u in jax_unimpl], harness.fullname)
            return
        try:
            harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
        except Exception as e:
            if jax_unimpl:
                logging.info(
                    "Found expected JAX error %s with expected JAX limitations: "
                    "%s in harness %s", e, [u.description for u in jax_unimpl],
                    harness.fullname)
                return
            else:
                raise e

        if jax_unimpl:
            logging.warning(
                "Found no JAX error but expected JAX limitations: %s in "
                "harness: %s", [u.description for u in jax_unimpl],
                harness.fullname)
Example #8
    def testDetrend(self, shape, dtype, axis, type, bp):
        signal = np.random.normal(loc=2, size=shape)

        if type == 'constant':
            trend = np.ones_like(signal)
        elif type == 'linear':
            trend = np.linspace(-np.pi, np.pi, shape[0])
            if len(shape) == 1:
                trend = np.broadcast_to(trend, shape)
            elif len(shape) == 2:
                trend = np.broadcast_to(trend[:, None], shape)
            elif len(shape) == 3:
                trend = np.broadcast_to(trend[:, None, None], shape)

        args_maker = lambda: [signal, trend]

        def osp_fun(signal, trend):
            return osp_signal.detrend(
                signal + trend, axis=axis, type=type, bp=bp) - trend

        def jsp_fun(signal, trend):
            return jsp_signal.detrend(
                signal + trend, axis=axis, type=type, bp=bp) - trend

        if jtu.device_under_test() == 'tpu':
            tol = {np.float32: 3e-2, np.float64: 1e-12}
        else:
            tol = {np.float32: 1e-5, np.float64: 1e-12}

        self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
Example #9
    def testWelchWithDefaultStepArgsAgainstNumpy(self, *, shape, dtype,
                                                 nperseg, noverlap,
                                                 use_nperseg, use_noverlap,
                                                 timeaxis):
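        # Pass either nperseg directly or a window of that length, from which
        # welch infers the segment length.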
        kwargs = {'axis': timeaxis}

        if use_nperseg:
            kwargs['nperseg'] = nperseg
        else:
            kwargs['window'] = osp_signal.get_window('hann', nperseg)
        if use_noverlap:
            kwargs['noverlap'] = noverlap

        osp_fun = partial(osp_signal.welch, **kwargs)
        jsp_fun = partial(jsp_signal.welch, **kwargs)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
Example #10
 def setUp(self):
     super().setUp()
     if jtu.device_under_test() != "tpu":
         raise SkipTest("serialize executable only works on TPU")
     if jax._src.lib.xla_bridge.get_backend().runtime_type == "tfrt":
         raise SkipTest(
             "the new TFRT runtime does not support serialization")
Example #11
    def test_spectral_dac_svd(self, linear_size, seed, dtype):
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            if jtu.device_under_test() != "cpu":
                raise unittest.SkipTest("Skip half precision off CPU.")

        rng = self.rng()
        A = rng.randn(linear_size, linear_size).astype(dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError, jax._src.scipy.eigh.svd, A)
            return
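        # The singular values must match numpy's, and U @ diag(S) @ V must
        # reconstruct A up to a norm-scaled tolerance.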
        S_expected = np.linalg.svd(A, compute_uv=False)
        U, S, V = jax._src.scipy.eigh.svd(A)
        recon = jnp.dot((U * jnp.expand_dims(S, 0)),
                        V,
                        precision=lax.Precision.HIGHEST)
        eps = jnp.finfo(dtype).eps
        eps = eps * jnp.linalg.norm(A) * 15
        self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
        self.assertAllClose(A, recon, atol=eps)

        # U is unitary.
        u_unitary_delta = jnp.dot(U.conj().T,
                                  U,
                                  precision=lax.Precision.HIGHEST)
        u_eye = jnp.eye(u_unitary_delta.shape[0], dtype=dtype)
        self.assertAllClose(u_unitary_delta, u_eye, atol=eps)

        # V is unitary.
        v_unitary_delta = jnp.dot(V.conj().T,
                                  V,
                                  precision=lax.Precision.HIGHEST)
        v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
        self.assertAllClose(v_unitary_delta, v_eye, atol=eps)
Example #12
  def testTensorFlowToJax(self, shape, dtype):
    if not config.x64_enabled and dtype in [jnp.int64, jnp.uint64, jnp.float64]:
      raise self.skipTest("x64 types are disabled by jax_enable_x64")
    if (jtu.device_under_test() == "gpu" and
        not tf.config.list_physical_devices("GPU")):
      raise self.skipTest("TensorFlow not configured with GPU support")

    if jtu.device_under_test() == "gpu" and dtype == jnp.int32:
      raise self.skipTest("TensorFlow does not place int32 tensors on GPU")

    rng = jtu.rand_default(self.rng())
    np = rng(shape, dtype)
    with tf.device("/GPU:0" if jtu.device_under_test() == "gpu" else "/CPU:0"):
      x = tf.identity(tf.constant(np))
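    # Round-trip through DLPack: TF tensor -> DLPack capsule -> JAX array.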
    dlpack = tf.experimental.dlpack.to_dlpack(x)
    y = jax.dlpack.from_dlpack(dlpack)
    self.assertAllClose(np, y)
Example #13
 def test_eval_numpy_no_copy(self):
   if jtu.device_under_test() != "cpu":
     raise unittest.SkipTest("no_copy test works only on CPU")
   # For ndarray, zero-copy only works for sufficiently-aligned arrays.
   x = np.ones((16, 16), dtype=np.float32)
   res = jax2tf.call_tf(lambda x: x)(x)
   self.assertAllClose(x, res)
   self.assertTrue(np.shares_memory(x, res))
Example #14
 def test_eval_devicearray_no_copy(self):
   if jtu.device_under_test() != "cpu":
     # TODO(necula): add tests for GPU and TPU
     raise unittest.SkipTest("no_copy test works only on CPU")
   # For DeviceArray, zero-copy works even if the array is not aligned.
   x = jnp.ones((3, 3))
   res = jax2tf.call_tf(lambda x: x)(x)
   self.assertAllClose(x, res)
   self.assertTrue(np.shares_memory(x, res))
Example #15
    def testPolar(self, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
                  shape, method, side, nonzero_condition_number, dtype, seed):
        """ Tests jax.scipy.linalg.polar."""
        if jtu.device_under_test() != "cpu":
            if jnp.dtype(dtype).name in ("bfloat16", "float16"):
                raise unittest.SkipTest("Skip half precision off CPU.")

        m, n = shape
        if (method == "qdwh" and ((side == "left" and m >= n) or
                                  (side == "right" and m < n))):
            raise unittest.SkipTest("method=qdwh does not support these sizes")

        matrix, _ = _initialize_polar_test(self.rng(), shape, n_zero_sv,
                                           degeneracy, geometric_spectrum,
                                           max_sv, nonzero_condition_number,
                                           dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError,
                              jsp.linalg.polar,
                              matrix,
                              method=method,
                              side=side)
            return

        unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
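        # For a tall matrix the unitary factor has orthonormal columns
        # (U^H U = I); for a wide one, orthonormal rows (U U^H = I).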
        if shape[0] >= shape[1]:
            should_be_eye = np.matmul(unitary.conj().T, unitary)
        else:
            should_be_eye = np.matmul(unitary, unitary.conj().T)
        tol = 500 * float(jnp.finfo(matrix.dtype).eps)
        eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
        with self.subTest('Test unitarity.'):
            self.assertAllClose(eye_mat, should_be_eye, atol=tol * min(shape))

        with self.subTest('Test Hermiticity.'):
            self.assertAllClose(posdef,
                                posdef.conj().T,
                                atol=tol * jnp.linalg.norm(posdef))

        ev, _ = np.linalg.eigh(posdef)
        ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
        negative_ev = jnp.sum(ev < 0.)
        with self.subTest('Test positive definiteness.'):
            self.assertEqual(negative_ev, 0)

        if side == "right":
            recon = jnp.matmul(unitary,
                               posdef,
                               precision=lax.Precision.HIGHEST)
        elif side == "left":
            recon = jnp.matmul(posdef,
                               unitary,
                               precision=lax.Precision.HIGHEST)
        with self.subTest('Test reconstruction.'):
            self.assertAllClose(matrix,
                                recon,
                                atol=tol * jnp.linalg.norm(matrix))
Example #16
    def testMultiBackendNestedJitConflict(self, ordering):
        outer, inner = ordering
        if outer not in ('cpu', jtu.device_under_test(), None):
            raise SkipTest("Backend is not CPU or the device under test")
        if inner not in ('cpu', jtu.device_under_test(), None):
            raise SkipTest("Backend is not CPU or the device under test")

        @partial(jax.jit, backend=outer)
        def fun(x, y):
            @partial(jax.jit, backend=inner)
            def infun(x, y):
                return jnp.matmul(x, y)

            return infun(x, y) + jnp.ones_like(x)

        x = npr.uniform(size=(10, 10))
        y = npr.uniform(size=(10, 10))
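        # With conflicting outer/inner backend requests the call must raise.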
        self.assertRaises(ValueError, lambda: fun(x, y))
Example #17
    def testCsdWithSameParamAgainstNumpy(self, *, shape, dtype, fs, window,
                                         nperseg, noverlap, nfft, detrend,
                                         scaling, timeaxis, average):
        is_complex = np.dtype(dtype).kind == 'c'
        if is_complex and detrend is not None:
            raise unittest.SkipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        def osp_fun(x, y):
            # jsp_signal.csd does not special-case identical inputs, so pass a
            # copy of y to make osp_signal.csd follow the same code path.
            freqs, Pxy = osp_signal.csd(x,
                                        y.copy(),
                                        fs=fs,
                                        window=window,
                                        nperseg=nperseg,
                                        noverlap=noverlap,
                                        nfft=nfft,
                                        detrend=detrend,
                                        return_onesided=not is_complex,
                                        scaling=scaling,
                                        axis=timeaxis,
                                        average=average)
            return freqs, Pxy

        jsp_fun = partial(jsp_signal.csd,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=not is_complex,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)

        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)] * 2

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
Example #18
 def testTorchToJax(self, shape, dtype):
   if not config.x64_enabled and dtype in [jnp.int64, jnp.float64]:
     self.skipTest("x64 types are disabled by jax_enable_x64")
   rng = jtu.rand_default(self.rng())
   np = rng(shape, dtype)
   x = torch.from_numpy(np)
   x = x.cuda() if jtu.device_under_test() == "gpu" else x
   dlpack = torch.utils.dlpack.to_dlpack(x)
   y = jax.dlpack.from_dlpack(dlpack)
   self.assertAllClose(np, y)
Example #19
    def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                              noverlap, nfft, onesided, boundary, timeaxis,
                              freqaxis):
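        # A two-sided input spectrum needs the full FFT length along the
        # frequency axis; recover it from the one-sided length.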
        if not onesided:
            new_freq_len = (shape[freqaxis] - 1) * 2
            shape = shape[:freqaxis] + (new_freq_len, ) + shape[freqaxis + 1:]

        def osp_fun(x, fs):
            # Ignore UserWarning in osp so we can also test over ill-posed cases.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                result = osp_signal.istft(x,
                                          fs=fs,
                                          window=window,
                                          nperseg=nperseg,
                                          noverlap=noverlap,
                                          nfft=nfft,
                                          input_onesided=onesided,
                                          boundary=boundary,
                                          time_axis=timeaxis,
                                          freq_axis=freqaxis)
            return result

        jsp_fun = partial(jsp_signal.istft,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          input_onesided=onesided,
                          boundary=boundary,
                          time_axis=timeaxis,
                          freq_axis=freqaxis)

        tol = {
            np.float32: 1e-4,
            np.float64: 1e-6,
            np.complex64: 1e-4,
            np.complex128: 1e-6
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        rng_fs = jtu.rand_uniform(self.rng(), 1.0, 16000.0)
        args_maker = lambda: [rng(shape, dtype), rng_fs((), np.float_)]

        # The output dtype differs across scipy versions, and hence across test
        # environments, so the dtype check is disabled.
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol,
                                check_dtypes=False)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
Example #20
    def testMultiBackendNestedJit(self, ordering):
        outer, inner = ordering
        if outer not in ('cpu', jtu.device_under_test(), None):
            raise SkipTest("Backend is not CPU or the device under test")

        @partial(jax.jit, backend=outer)
        def fun(x, y):
            @partial(jax.jit, backend=inner)
            def infun(x, y):
                return jnp.matmul(x, y)

            return infun(x, y) + jnp.ones_like(x)

        x = npr.uniform(size=(10, 10))
        y = npr.uniform(size=(10, 10))
        z_host = np.matmul(x, y) + np.ones_like(x)
        z = fun(x, y)
        self.assertAllClose(z, z_host, rtol=1e-2)
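        # Placement follows the outer jit's backend, or the default device when
        # the backend is None.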
        correct_platform = outer if outer else jtu.device_under_test()
        self.assertEqual(z.device_buffer.platform(), correct_platform)
Example #21
    def test_jit_cache(self):
        if jtu.device_under_test() == "tpu":
            self.skipTest("64-bit random not available on TPU")

        f = partial(random.uniform, random.PRNGKey(0), (1, ), 'float64', -1, 1)
        with disable_x64():
            for _ in range(2):
                f()
        with enable_x64():
            for _ in range(2):
                f()
Example #22
 def test_prim(self, harness: primitive_harness.Harness):
   limitations = Jax2TfLimitation.limitations_for_harness(harness)
   device = jtu.device_under_test()
   limitations = tuple(filter(lambda l: l.filter(device=device,
                                                 dtype=harness.dtype), limitations))
   func_jax = harness.dyn_fun
   args = harness.dyn_args_maker(self.rng())
   enable_xla = harness.params.get("enable_xla", True)
   associative_scan_reductions = harness.params.get("associative_scan_reductions", False)
   with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
     self.ConvertAndCompare(func_jax, *args, limitations=limitations,
                            enable_xla=enable_xla)
Example #23
  def test_with_var_read_x64(self, with_jit=True):
    if jtu.device_under_test() == "gpu":
      raise unittest.SkipTest("Test fails on GPU")
    outer_var_array = np.array([3., 4.], dtype=np.float64)
    outer_var = tf.Variable(outer_var_array)

    def fun_tf(x):
      return x * tf.cast(outer_var, x.dtype) + 1.

    x = np.array([2., 5.,], dtype=np.float32)
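    # call_tf must read the captured tf.Variable's current value.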
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)
Example #24
    def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                              noverlap, nfft, detrend, return_onesided,
                              scaling, timeaxis, average):
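        # Complex inputs only support a two-sided spectrum, and the lax-backed
        # detrend does not handle complex signals.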
        if np.dtype(dtype).kind == 'c':
            return_onesided = False
            if detrend is not None:
                raise unittest.SkipTest(
                    "Complex signal is not supported in lax-backed `signal.detrend`."
                )

        osp_fun = partial(osp_signal.welch,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=return_onesided,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)
        jsp_fun = partial(jsp_signal.welch,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=return_onesided,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
Example #25
  def test_with_multiple_capture(self, with_jit=True):
    if jtu.device_under_test() == "gpu":
      raise unittest.SkipTest("Test fails on GPU")
    v2 = tf.Variable(2., dtype=np.float32)
    v3 = tf.Variable(3., dtype=np.float32)
    t4 = tf.constant(4., dtype=np.float32)
    t5 = tf.constant(5., dtype=np.float32)

    def fun_tf(x):
      return (x * v3 + t4 + v2) * v3 + t5

    x = np.float32(2.)
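    # call_tf must thread the two captured tf.Variables and the two
    # tf.constants through to the TF computation.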
    res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
    self.assertAllClose((x * 3. + 4. + 2.) * 3. + 5., res, check_dtypes=False)
Example #26
    def testGpuMultiBackendOpByOpReturn(self, backend):
        if backend not in ('cpu', jtu.device_under_test()):
            raise SkipTest("Backend is not CPU or the device under test")

        @partial(jax.jit, backend=backend)
        def fun(x, y):
            return jnp.matmul(x, y)

        x = npr.uniform(size=(10, 10))
        y = npr.uniform(size=(10, 10))
        z = fun(x, y)
        w = jnp.sin(z)
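        # The op-by-op follow-up computation stays on the same backend as the
        # jitted result.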
        self.assertEqual(z.device_buffer.platform(), backend)
        self.assertEqual(w.device_buffer.platform(), backend)
Example #27
  def test_with_var_different_shape(self):
    # See https://github.com/google/jax/issues/6050
    if jtu.device_under_test() == "gpu":
      raise unittest.SkipTest("Test fails on GPU")
    v = tf.Variable((4., 2.), dtype=tf.float32)

    def tf_func(x):
      return v + x
    x = np.float32(123.)
    tf_out = tf_func(x)

    jax_func = jax.jit(jax2tf.call_tf(tf_func))
    jax_out = jax_func(x)

    self.assertAllClose(tf_out, jax_out, check_dtypes=False)
Example #28
    def test_spectral_dac_eigh(self, linear_size, seed, dtype,
                               termination_size):
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            if jtu.device_under_test() != "cpu":
                raise unittest.SkipTest("Skip half precision off CPU.")

        if jtu.device_under_test() != "tpu" and termination_size != 1:
            raise unittest.SkipTest(
                "Termination sizes greater than 1 only work on TPU")

        rng = self.rng()
        H = rng.randn(linear_size, linear_size)
        H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError, jax._src.lax.eigh.eigh, H)
            return
        evs, V = jax._src.lax.eigh.eigh(H, termination_size=termination_size)
        ev_exp, eV_exp = jnp.linalg.eigh(H)
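        # Residual check: H @ V should equal V scaled columnwise by the
        # eigenvalues, up to a norm-scaled tolerance.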
        HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
        vV = evs[None, :] * V
        eps = jnp.finfo(H.dtype).eps
        atol = jnp.linalg.norm(H) * eps
        self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
        self.assertAllClose(HV, vV, atol=30 * atol)
Example #29
    def setUp(self):
        super().setUp()
        # Ensure that all TF ops are created on the proper device (TPU or GPU or CPU)
        tf_preferred_devices = (tf.config.list_logical_devices("TPU") +
                                tf.config.list_logical_devices("GPU") +
                                tf.config.list_logical_devices())
        self.tf_default_device = tf_preferred_devices[0]
        logging.info("Running jax2tf converted code on %s.",
                     self.tf_default_device)
        # We need --config=cuda build flag for TF to see the GPUs
        self.assertEqual(jtu.device_under_test().upper(),
                         self.tf_default_device.device_type)

        with contextlib.ExitStack() as stack:
            stack.enter_context(tf.device(self.tf_default_device))
            self.addCleanup(stack.pop_all().close)
Example #30
    def testProgrammaticProfilingContextManager(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            with jax.profiler.trace(tmpdir):
                jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'),
                         axis_name='i')(jnp.ones(jax.local_device_count()))

            proto_path = glob.glob(os.path.join(tmpdir, "**/*.xplane.pb"),
                                   recursive=True)
            self.assertEqual(len(proto_path), 1)
            with open(proto_path[0], "rb") as f:
                proto = f.read()
            # Sanity check that serialized proto contains host and device traces
            # without deserializing.
            self.assertIn(b"/host:CPU", proto)
            if jtu.device_under_test() == "tpu":
                self.assertIn(b"/device:TPU", proto)