def _rfft_transpose(t, fft_lengths):
  """Transpose rule for the real-to-complex FFT.

  The transpose of RFFT cannot be written purely in terms of IRFFT. Building
  the required larger twiddle matrices by hand would raise the asymptotic
  cost and is fiddly. Instead, this asks JAX to transpose a naive RFFT
  implementation (`_naive_rfft`).

  Args:
    t: the value fed to the transposed naive RFFT (the cotangent). Its
      trailing `len(fft_lengths)` dimensions are replaced by `fft_lengths` to
      form the primal shape.
    fft_lengths: the FFT lengths, appended to the leading dimensions of `t`
      to build the primal shape. Presumably a tuple, since it is concatenated
      with `t.shape`.

  Returns:
    The transposed value. Its dtype is the real counterpart of `t.dtype`;
    the assertion below checks this.
  """
  leading_dims = t.shape[:-len(fft_lengths)]
  real_dtype = _real_dtype(t.dtype)
  # Only shape and dtype are needed to describe the primal input to transpose.
  primal_spec = ShapeDtypeStruct(leading_dims + fft_lengths, real_dtype)
  naive_rfft = partial(_naive_rfft, fft_lengths=fft_lengths)
  # linear_transpose returns one output per primal argument; there is one here.
  (transposed,) = linear_transpose(naive_rfft, primal_spec)(t)
  assert transposed.dtype == real_dtype, (transposed.dtype, t.dtype)
  return transposed
def get_ntk(x1, x2, *args):
  """Contract the Jacobians of `f` at `x1` and `x2` via VJP, JVP and transpose.

  The positional `args` are split into two halves. The first half supplies
  keyword values (paired with `keys`) for the `x1` application. The second
  half supplies them for the `x2` application.

  Presumably this computes the empirical NTK between `x1` and `x2`. TODO:
  confirm against the enclosing function's documentation.

  Args:
    x1: first input.
    x2: second input. If `utils.all_none(x2)` holds, the `x1` function is
      reused for both sides and the second half of `args` is ignored.
    *args: keyword values for the `x1` side, followed by keyword values for
      the `x2` side.

  Returns:
    The result of `_diagonal` applied to the unravelled, transposed
    VJP-then-JVP map evaluated on the standard basis of `f1`'s output.
  """
  half = len(args) // 2
  kwargs1 = dict(zip(keys, args[:half]))
  kwargs2 = dict(zip(keys, args[half:]))

  f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **kwargs1)
  if utils.all_none(x2):
    f2 = f1
  else:
    f2 = _get_f_params(f, x2, x_axis, fx_axis, kw_axes, **kwargs2)

  def vjp_then_jvp(cotangent):
    # Pull the output-space cotangent back to parameter space through f2...
    param_cotangent = vjp(f2, params)[1](cotangent)
    # ...then push that parameter-space vector forward through f1.
    return jvp(f1, (params,), param_cotangent)[1]

  fx1 = eval_shape(f1, params)
  fx2 = eval_shape(f2, params)
  basis = _std_basis(fx1)
  # Transpose the linear map (it is linear in `cotangent`) and apply it to
  # every basis vector of f1's output at once.
  ntk = vmap(linear_transpose(vjp_then_jvp, fx2))(basis)
  ntk = tree_map(lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk)
  return _diagonal(ntk, fx1)
def test_linear_transpose_imag(self):
  # Transposing `z -> z.imag` at a complex primal should send a real
  # cotangent of 1 back to -1j.
  def imag_part(z):
    return z.imag

  (ct,) = api.linear_transpose(imag_part, 1.j)(1.)
  self.assertEqual(ct, -1.j)
def test_linear_transpose_real(self):
  # Transposing `z -> z.real` at a complex primal should send a real
  # cotangent of 1 back to 1.
  def real_part(z):
    return z.real

  (ct,) = api.linear_transpose(real_part, 1.j)(1.)
  self.assertEqual(ct, 1.)