Esempio n. 1
0
def _rfft_transpose(t, fft_lengths):
    """Transpose (adjoint) rule for RFFT.

    The transpose of RFFT can't be expressed only in terms of irfft. Instead of
    manually building up larger twiddle matrices (which would increase the
    asymptotic complexity and is also rather complicated), we rely on JAX to
    transpose a naive RFFT implementation.

    Args:
      t: Cotangent of the RFFT output. Its leading axes (all but the last
        ``len(fft_lengths)``) are carried through unchanged.
      fft_lengths: Lengths of the trailing, transformed axes of the real RFFT
        input. Forwarded unchanged to ``_naive_rfft``.

    Returns:
      The cotangent of the RFFT input, with shape
      ``t.shape[:num_leading] + tuple(fft_lengths)`` and dtype
      ``_real_dtype(t.dtype)``.
    """
    fft_shape = tuple(fft_lengths)
    # Explicit, non-negative stop index: ``t.shape[:-0]`` would be ``()`` and
    # silently drop every leading axis if ``fft_lengths`` were empty.
    num_leading = max(len(t.shape) - len(fft_shape), 0)
    dummy_shape = tuple(t.shape[:num_leading]) + fft_shape
    # ``linear_transpose`` only reads the shape/dtype of the primal, so a
    # ShapeDtypeStruct is enough; no real input array is materialized.
    dummy_primal = ShapeDtypeStruct(dummy_shape, _real_dtype(t.dtype))
    transpose = linear_transpose(partial(_naive_rfft, fft_lengths=fft_lengths),
                                 dummy_primal)
    # ``linear_transpose`` returns one cotangent per primal argument.
    result, = transpose(t)
    # Internal invariant check (not input validation).
    assert result.dtype == _real_dtype(t.dtype), (result.dtype, t.dtype)
    return result
Esempio n. 2
0
        def get_ntk(x1, x2, *args):
            """Compute ``ntk`` for ``x1``/``x2`` by transposing a VJP-then-JVP map.

            Args:
              x1: First input, used to build ``f1``.
              x2: Second input, used to build ``f2``; when ``utils.all_none(x2)``
                is true, ``f1`` is reused as ``f2``.
              *args: Keyword-argument values for both sides, concatenated. The
                first half is zipped with ``keys`` for ``x1`` and the second
                half for ``x2``.

            Returns:
              The transposed product, vmapped over a standard basis of ``f1``'s
              output, unraveled into ``f1``'s output pytree and passed through
              ``_diagonal``.
            """
            half = len(args) // 2
            args1, args2 = args[:half], args[half:]
            _kwargs1 = dict(zip(keys, args1))
            _kwargs2 = dict(zip(keys, args2))

            f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **_kwargs1)
            f2 = f1 if utils.all_none(x2) else _get_f_params(
                f, x2, x_axis, fx_axis, kw_axes, **_kwargs2)

            def delta_vjp_jvp(delta):
                # Pull ``delta`` back through f2 to parameter space (VJP), then
                # push the result forward through f1 (JVP). Both steps are
                # linear in ``delta``, so the composition can be transposed.
                params_cotangent = vjp(f2, params)[1](delta)
                return jvp(f1, (params, ), params_cotangent)[1]

            # Only shapes/dtypes of the outputs are needed here.
            fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params)
            eye = _std_basis(fx1)
            ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye)
            ntk = tree_map(
                lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk)
            ntk = _diagonal(ntk, fx1)
            return ntk
Esempio n. 3
0
 def test_linear_transpose_imag(self):
   f = lambda x: x.imag
   transpose = api.linear_transpose(f, 1.j)
   actual, = transpose(1.)
   expected = -1.j
   self.assertEqual(actual, expected)
Esempio n. 4
0
 def test_linear_transpose_real(self):
   f = lambda x: x.real
   transpose = api.linear_transpose(f, 1.j)
   actual, = transpose(1.)
   expected = 1.
   self.assertEqual(actual, expected)