Exemplo n.º 1
0
    def test_std_basis(self):
        """Check `api._std_basis` on single arrays and on a nested pytree."""
        # Single arrays: the basis is an identity matrix whose rows are
        # reshaped to the input's shape and stacked on a new leading axis.
        array_cases = (
            (np.zeros(3), (3, 3), onp.eye(3)),
            (np.zeros((3, 3)), (9, 3, 3), onp.eye(9).reshape(9, 3, 3)),
        )
        for inp, want_shape, want in array_cases:
            got = api._std_basis(inp)
            assert getattr(got, "shape", None) == want_shape
            assert onp.allclose(got, want)

        # Pytree input with 1 + 3 + 12 = 16 scalar entries: every leaf of the
        # result gains a leading axis of size 16, and the list/tuple nesting
        # of the input is kept.
        tree_basis = api._std_basis([0., (np.zeros(3), np.zeros((3, 4)))])
        assert isinstance(tree_basis, list)
        assert len(tree_basis) == 2
        scalar_leaf, pair = tree_basis
        assert getattr(scalar_leaf, "shape", None) == (16, )
        assert isinstance(pair, tuple)
        assert len(pair) == 2
        vec_leaf, mat_leaf = pair
        assert getattr(vec_leaf, "shape", None) == (16, 3)
        assert getattr(mat_leaf, "shape", None) == (16, 3, 4)
Exemplo n.º 2
0
        def get_ntk(x1, x2, *args):
            """Compute the kernel block between `f`'s outputs on `x1` and `x2`.

            `args` is split at `len(args) // 2`. The first half is paired
            with `keys` as keyword arguments for the `x1` side. The second
            half is paired with `keys` for the `x2` side.

            Writing J1, J2 for the Jacobians of f1, f2 with respect to
            `params`, the kernel J1 @ J2^T is built one row per scalar entry
            of f1's output, without materializing J1 or J2. This assumes
            `vjp`/`jvp`/`linear_transpose`/`vmap` are the JAX transforms of
            the same name (TODO confirm against the enclosing module's
            imports).

            Args:
              x1: first input.
              x2: second input. When `utils.all_none(x2)` is true, the
                function built for `x1` is reused for both sides, and the
                second half of `args` is ignored.
              *args: keyword-argument values for the `x1` call, followed by
                those for the `x2` call.

            Returns:
              The result of `_diagonal(ntk, fx1)` applied to the kernel.
            """
            half = len(args) // 2
            kwargs1 = dict(zip(keys, args[:half]))
            kwargs2 = dict(zip(keys, args[half:]))

            f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **kwargs1)
            f2 = f1 if utils.all_none(x2) else _get_f_params(
                f, x2, x_axis, fx_axis, kw_axes, **kwargs2)

            def delta_vjp_jvp(delta):
                # Linear map: delta (shaped like f2's output) -> J1 @ J2^T @ delta.
                # Pull back through f2 (VJP), then push forward through f1 (JVP).
                params_cotangents = vjp(f2, params)[1](delta)
                return jvp(f1, (params, ), params_cotangents)[1]

            fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params)
            # One basis element per scalar entry of f1's output.
            eye = _std_basis(fx1)
            # delta_vjp_jvp maps fx2-shaped inputs to fx1-shaped outputs, so its
            # transpose maps fx1 -> fx2. Mapping that transpose over the basis
            # of fx1 yields the rows of J1 @ J2^T on a new leading axis.
            ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye)
            # Presumably unflattens the leading axis into fx1's pytree/shape
            # (per the helper's name) — confirm.
            ntk = tree_map(
                lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk)
            # NOTE(review): `_diagonal` is defined elsewhere; presumably it
            # applies the caller's diagonal/trace axis handling — confirm.
            return _diagonal(ntk, fx1)