Example #1
    def ntk_fn(x1: NTTree[np.ndarray], x2: Optional[NTTree[np.ndarray]],
               params: PyTree, **apply_fn_kwargs) -> np.ndarray:
        """Computes a single sample of the empirical NTK (implicit differentiation).

        Args:
          x1:
            first batch of inputs.
          x2:
            second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
            matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
          params:
            A `PyTree` of parameters about which we would like to compute the
            neural tangent kernel.
          **apply_fn_kwargs:
            keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be
            split into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the
            `split_kwargs` function, and these are then passed to `apply_fn`. In
            particular, the rng key in `apply_fn_kwargs` will be split into two
            different (if `x1 != x2`) or identical (if `x1 == x2`) rng keys. See
            the `_read_key` function for more details.

        Returns:
          A single sample of the empirical NTK. The shape of the kernel is
          "almost" `zip(f(x1).shape, f(x2).shape)` except that:
          1) `trace_axes` are absent as they are contracted over.
          2) `diagonal_axes` are present only once.
          All other axes are present twice.
        """
        kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)
        f1 = _flatten(_get_f_params(f, x1, **kwargs1))
        f2 = (f1 if utils.all_none(x2) else _flatten(
            _get_f_params(f, x2, **kwargs2)))

        def delta_vjp_jvp(delta):
            def delta_vjp(delta):
                return vjp(f2, params)[1](delta)

            return jvp(f1, (params, ), delta_vjp(delta))[1]

        # Since we are taking the Jacobian of a linear function (which does not
        # depend on its coefficients), it is more efficient to substitute fx_dummy
        # for the outputs of the network. fx_dummy has the same shape as the output
        # of the network on a single piece of input data.
        fx2_struct = eval_shape(f2, params)

        @utils.nt_tree_fn()
        def dummy_output(fx_struct):
            return np.ones(fx_struct.shape, fx_struct.dtype)

        fx_dummy = dummy_output(fx2_struct)

        ntk = jacobian(delta_vjp_jvp)(fx_dummy)
        if utils.is_list_or_tuple(fx_dummy):
            fx_treedef = tree_structure(
                eval_shape(_get_f_params(f, x1, **kwargs1), params))
            ntk = [ntk[i][i] for i in range(len(fx_dummy))]
            ntk = tree_unflatten(fx_treedef, ntk)

        return _trace_and_diagonal(ntk, trace_axes, diagonal_axes)
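
The example above never materializes the Jacobians: `delta_vjp_jvp` is the linear map `delta -> J(x1) J(x2)^T delta`, built from one VJP through `f2` and one JVP through `f1`, and taking `jacobian` of this linear map recovers the NTK matrix. Below is a minimal, self-contained sketch of the same identity on a toy linear model; `toy_f`, the shapes, and the explicit reference computation are illustrative assumptions, not part of the snippet above.

    import jax
    import jax.numpy as jnp

    # Toy model: f(params, x) = x @ params, outputs of shape (batch, out_dim).
    def toy_f(params, x):
        return x @ params

    params = jnp.arange(6.0).reshape(3, 2)     # (in_dim, out_dim)
    x1 = jnp.ones((4, 3))
    x2 = jnp.ones((5, 3)) * 2.0

    f1 = lambda p: toy_f(p, x1)                # params -> outputs of shape (4, 2)
    f2 = lambda p: toy_f(p, x2)                # params -> outputs of shape (5, 2)

    def delta_vjp_jvp(delta):
        # delta has the shape of f2's output; this computes J1 @ J2^T @ delta
        # without forming either Jacobian explicitly.
        delta_vjp = jax.vjp(f2, params)[1](delta)
        return jax.jvp(f1, (params,), delta_vjp)[1]

    # The Jacobian of a linear map is its matrix, so this materializes the NTK.
    ntk_implicit = jax.jacobian(delta_vjp_jvp)(jnp.ones_like(f2(params)))

    # Reference: explicit Jacobian contraction over the parameter axes.
    j1 = jax.jacobian(f1)(params).reshape(4 * 2, -1)
    j2 = jax.jacobian(f2)(params).reshape(5 * 2, -1)
    ntk_explicit = (j1 @ j2.T).reshape(4, 2, 5, 2)

    assert jnp.allclose(ntk_implicit, ntk_explicit, atol=1e-5)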
Example #2
    def ntk_fn(x1: NTTree[np.ndarray], x2: Optional[NTTree[np.ndarray]],
               params: PyTree, **apply_fn_kwargs) -> np.ndarray:
        """Computes a single sample of the empirical NTK (jacobian outer product).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split
        into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs`
        function which will be passed to `apply_fn`. In particular, the rng key
        in `apply_fn_kwargs`, will be split into two different (if `x1!=x2`) or
        same (if `x1==x2`) rng keys. See the `_read_key` function for more
        details.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
        kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)
        fx1 = eval_shape(f, params, x1, **kwargs1)
        x_axis, fx_axis, kw_axes = _canonicalize_axes(vmap_axes, x1, fx1,
                                                      **kwargs1)

        keys = apply_fn_kwargs.keys()
        args1, args2 = (kwargs1[k] for k in keys), (kwargs2[k] for k in keys)

        def j_fn(x, *args):
            _kwargs = {k: v for k, v in zip(keys, args)}
            fx = _get_f_params(f, x, x_axis, fx_axis, kw_axes, **_kwargs)
            jx = jacobian(fx)(params)
            return jx

        if x_axis is not None or kw_axes:
            in_axes = [x_axis] + [
                kw_axes[k] if k in kw_axes else None for k in keys
            ]
            j_fn = vmap(j_fn, in_axes=in_axes, out_axes=fx_axis)

        j1 = j_fn(x1, *args1)
        j2 = j_fn(x2, *args2) if not utils.all_none(x2) else j1
        ntk = sum_and_contract(fx1, j1, j2)
        return ntk
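
This variant materializes the per-example Jacobians with `jacobian` (optionally `vmap`-ed over the batch axis) and then contracts them; `sum_and_contract` sums the per-leaf outer products `J1 . J2^T` over all parameter leaves (and, in the library, also handles `trace_axes`/`diagonal_axes`). A minimal sketch of that contraction on a toy two-leaf parameter pytree; `toy_f`, `contract`, and the shapes are illustrative assumptions, not the library's helpers.

    import jax
    import jax.numpy as jnp

    def toy_f(params, x):
        # Two-layer model; params is a pytree with two leaves.
        return jnp.tanh(x @ params['w1']) @ params['w2']

    params = {'w1': jnp.ones((3, 4)), 'w2': jnp.ones((4, 2))}
    x1, x2 = jnp.ones((5, 3)), jnp.ones((6, 3)) * 0.5

    j1 = jax.jacobian(lambda p: toy_f(p, x1))(params)   # leaves of shape (5, 2, *leaf.shape)
    j2 = jax.jacobian(lambda p: toy_f(p, x2))(params)   # leaves of shape (6, 2, *leaf.shape)

    def contract(a, b):
        # Contract over all parameter axes (everything after the two output axes).
        param_axes = list(range(2, a.ndim))
        return jnp.tensordot(a, b, (param_axes, param_axes))

    # Sum the per-leaf Jacobian outer products -> NTK of shape (5, 2, 6, 2).
    leaves1, leaves2 = jax.tree_util.tree_leaves(j1), jax.tree_util.tree_leaves(j2)
    ntk = sum(contract(a, b) for a, b in zip(leaves1, leaves2))
    print(ntk.shape)   # (5, 2, 6, 2)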
Example #3
    def nngp_fn(x1: np.ndarray, x2: Optional[np.ndarray], params: PyTree,
                **apply_fn_kwargs) -> np.ndarray:
        """Computes a single sample of the empirical NNGP.

        Args:
          x1:
            first batch of inputs.
          x2:
            second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
            matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
          params:
            A `PyTree` of parameters about which we would like to compute the
            NNGP.
          **apply_fn_kwargs:
            keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be
            split into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the
            `split_kwargs` function, and these are then passed to `apply_fn`. In
            particular, the rng key in `apply_fn_kwargs` will be split into two
            different (if `x1 != x2`) or identical (if `x1 == x2`) rng keys. See
            the `_read_key` function for more details.

        Returns:
          A single sample of the empirical NNGP. The shape of the kernel is
          "almost" `zip(f(x1).shape, f(x2).shape)` except that:
          1) `trace_axes` are absent as they are contracted over.
          2) `diagonal_axes` are present only once.
          All other axes are present twice.
        """
        def output(x, **kwargs):
            out = f(params, x, **kwargs)
            masked_output = utils.get_masked_array(out)
            return utils.nt_tree_fn()(lambda x: x.masked_value)(masked_output)

        kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)

        out1 = output(x1, **kwargs1)
        if utils.all_none(x2):
            out2 = out1
        else:
            out2 = output(x2, **kwargs2)

        @utils.nt_tree_fn()
        def contract(out1, out2):
            dot = utils.dot_general(out1, out2, trace_axes, diagonal_axes)
            return dot / utils.size_at(out1, trace_axes)

        return contract(out1, out2)
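
The empirical NNGP is just the normalized outer product of network outputs, contracted over `trace_axes`; for the common `(batch, out_dim)` output with `trace_axes=(-1,)` it reduces to `f(params, x1) @ f(params, x2).T / out_dim`. A minimal sketch of that reduction; `toy_f` and the shapes are illustrative assumptions.

    import jax.numpy as jnp

    def toy_f(params, x):
        return x @ params                     # outputs of shape (batch, out_dim)

    params = jnp.ones((3, 7))
    x1, x2 = jnp.ones((4, 3)), jnp.ones((5, 3)) * 2.0

    out1, out2 = toy_f(params, x1), toy_f(params, x2)

    # Contract over the output (trace) axis and normalize by its size,
    # mirroring `dot_general(out1, out2, trace_axes, ...) / size_at(...)`.
    nngp = jnp.tensordot(out1, out2, ((1,), (1,))) / out1.shape[1]
    print(nngp.shape)                         # (4, 5)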
Example #4
    def parallel_fn_x1(x1, x2=None, *args, **kwargs):
        x2_is_none = utils.all_none(x2)
        if x2_is_none:
            # TODO(schsam): Only compute the upper triangular part of the kernel.
            x2 = x1

        def get_batch_size(x):
            if utils.is_list_or_tuple(x):
                return get_batch_size(x[0])
            return x.shape[0]

        n1 = get_batch_size(x1)
        n2 = n1 if x2_is_none else get_batch_size(x2)

        _check_dropout(n1, n2, kwargs)
        n1_per_device, _device_count = _get_n_per_device(n1, n2)

        _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

        @utils.nt_tree_fn()
        def batch_data(x):
            input_shape = x.shape[1:]
            return np.reshape(x, (
                _device_count,
                n1_per_device,
            ) + input_shape)

        for k, v in kwargs.items():
            if _is_np_ndarray(v):
                assert isinstance(v, tuple) and len(v) == 2
                v0 = np.reshape(v[0], (
                    _device_count,
                    n1_per_device,
                ) + v[0].shape[1:])
                kwargs[k] = (v0, v[1])

        x1 = batch_data(x1)

        kernel = _kernel_fn(x1, x2, *args, **kwargs)
        return _flatten_kernel(kernel, x2_is_none, True)
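
`parallel_fn_x1` shards the first batch across devices by reshaping its leading axis into `(device_count, n1_per_device)` and letting the `pmap`-ed kernel map over the device axis, while the second input is broadcast; `_flatten_kernel` then folds the device axis back into the row axis. A minimal sketch of the same sharding pattern, assuming a toy pairwise kernel and that `n1` divides evenly across devices (the names here are illustrative, not the library's).

    import jax
    import jax.numpy as jnp

    def toy_kernel(x1, x2):
        # Toy pairwise kernel: plain Gram matrix between the two batches.
        return x1 @ x2.T

    device_count = jax.device_count()
    n1, n2, d = 8 * device_count, 6, 3
    x1, x2 = jnp.ones((n1, d)), jnp.ones((n2, d))

    # Shard x1 over devices; broadcast x2 to every device.
    n1_per_device = n1 // device_count
    x1_sharded = x1.reshape((device_count, n1_per_device) + x1.shape[1:])
    k_sharded = jax.pmap(toy_kernel, in_axes=(0, None))(x1_sharded, x2)

    # Un-shard ("flatten") the leading device axis back into the row axis.
    k = k_sharded.reshape((n1, n2))
    print(k.shape)   # (n1, n2)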
Example #5
        def get_ntk(x1, x2, *args):
            args1, args2 = args[:len(args) // 2], args[len(args) // 2:]
            _kwargs1 = {k: v for k, v in zip(keys, args1)}
            _kwargs2 = {k: v for k, v in zip(keys, args2)}

            f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **_kwargs1)
            f2 = f1 if utils.all_none(x2) else _get_f_params(
                f, x2, x_axis, fx_axis, kw_axes, **_kwargs2)

            def delta_vjp_jvp(delta):
                def delta_vjp(delta):
                    return vjp(f2, params)[1](delta)

                return jvp(f1, (params, ), delta_vjp(delta))[1]

            fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params)
            eye = _std_basis(fx1)
            ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye)
            ntk = tree_map(
                lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk)
            ntk = _diagonal(ntk, fx1)
            return ntk
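
`get_ntk` above materializes the NTK by transposing the linear map `delta_vjp_jvp` and applying the transpose to a standard basis of the output space: `vmap(linear_transpose(op, fx2))(eye)` evaluates the transposed operator on every basis cotangent at once, recovering the operator's matrix row by row. A minimal sketch of materializing a small linear operator this way; the operator `op` and matrix `A` are toy assumptions, not the library's helpers.

    import jax
    import jax.numpy as jnp

    A = jnp.arange(12.0).reshape(3, 4)

    def op(v):
        # A linear operator R^4 -> R^3 whose matrix we want to materialize.
        return A @ v

    v_example = jnp.zeros((4,))
    op_t = jax.linear_transpose(op, v_example)   # maps output cotangents to input cotangents

    # Apply the transposed operator to every standard basis vector of the
    # output space at once; row i of the result is A[i, :], i.e. we recover A.
    eye = jnp.eye(3)
    (A_materialized,) = jax.vmap(op_t)(eye)

    assert jnp.allclose(A_materialized, A)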
Example #6
    def ntk_fn(x1: NTTree[np.ndarray], x2: Optional[NTTree[np.ndarray]],
               params: PyTree, **apply_fn_kwargs) -> np.ndarray:
        """Computes a single sample of the empirical NTK (implicit differentiation).

        Args:
          x1:
            first batch of inputs.
          x2:
            second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
            matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
          params:
            A `PyTree` of parameters about which we would like to compute the
            neural tangent kernel.
          **apply_fn_kwargs:
            keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be
            split into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the
            `split_kwargs` function, and these are then passed to `apply_fn`. In
            particular, the rng key in `apply_fn_kwargs` will be split into two
            different (if `x1 != x2`) or identical (if `x1 == x2`) rng keys. See
            the `_read_key` function for more details.

        Returns:
          A single sample of the empirical NTK. The shape of the kernel is
          "almost" `zip(f(x1).shape, f(x2).shape)` except that:
          1) `trace_axes` are absent as they are contracted over.
          2) `diagonal_axes` are present only once.
          All other axes are present twice.
        """
        kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2)
        fx1 = eval_shape(f, params, x1, **kwargs1)
        x_axis, fx_axis, kw_axes = _canonicalize_axes(vmap_axes, x1, fx1,
                                                      **kwargs1)

        keys = apply_fn_kwargs.keys()
        args1 = (kwargs1[k] for k in keys)
        args2 = (kwargs1[k]
                 if k in kw_axes and kwargs2[k] is None else kwargs2[k]
                 for k in keys)

        def get_ntk(x1, x2, *args):
            args1, args2 = args[:len(args) // 2], args[len(args) // 2:]
            _kwargs1 = {k: v for k, v in zip(keys, args1)}
            _kwargs2 = {k: v for k, v in zip(keys, args2)}

            f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **_kwargs1)
            f2 = f1 if utils.all_none(x2) else _get_f_params(
                f, x2, x_axis, fx_axis, kw_axes, **_kwargs2)

            def delta_vjp_jvp(delta):
                def delta_vjp(delta):
                    return vjp(f2, params)[1](delta)

                return jvp(f1, (params, ), delta_vjp(delta))[1]

            fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params)
            eye = _std_basis(fx1)
            ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye)
            ntk = tree_map(
                lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk)
            ntk = _diagonal(ntk, fx1)
            return ntk

        if x_axis is not None or kw_axes:
            x2 = x1 if utils.all_none(x2) else x2

            kw_in_axes = [kw_axes[k] if k in kw_axes else None for k in keys]
            in_axes1 = [x_axis, None] + kw_in_axes + [None] * len(kw_in_axes)
            in_axes2 = [None, x_axis] + [None] * len(kw_in_axes) + kw_in_axes

            get_ntk = vmap(vmap(get_ntk, in_axes1, fx_axis), in_axes2,
                           _add(fx_axis, _ndim(fx1)))

        return _trace_and_diagonal(get_ntk(x1, x2, *args1, *args2), trace_axes,
                                   diagonal_axes)
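
When `vmap_axes` is given, the per-pair `get_ntk` is lifted to full batches with two nested `vmap`s: the inner one maps over the `x1` batch axis with `x2` held fixed, the outer one maps over the `x2` batch axis, so the result carries one batch axis from each input. A minimal sketch of the same nested-`vmap` pattern for a pairwise function; `pair_fn` and the shapes are toy assumptions.

    import jax
    import jax.numpy as jnp

    def pair_fn(a, b):
        # Some scalar function of a single pair of examples.
        return jnp.dot(a, b)

    x1 = jnp.ones((4, 3))
    x2 = jnp.ones((5, 3)) * 2.0

    # Inner vmap: map over x1 with x2 fixed; outer vmap: map over x2.
    # Output axes are ordered so the x1 axis comes first, then the x2 axis.
    pairwise = jax.vmap(jax.vmap(pair_fn, in_axes=(0, None), out_axes=0),
                        in_axes=(None, 0), out_axes=1)
    print(pairwise(x1, x2).shape)   # (4, 5)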
Example #7
    def serial_fn_x1(x1: NTTree[np.ndarray],
                     x2: Optional[NTTree[Optional[np.ndarray]]] = None,
                     *args,
                     **kwargs) -> NTTree[Kernel]:

        x2_is_none = utils.all_none(x2)
        if x2_is_none:
            # TODO(schsam): Only compute the upper triangular part of the kernel.
            x2 = x1

        @utils.nt_tree_fn(reduce=lambda x: x[0])
        def get_n1_n2(x1, x2):
            n1, n2 = x1.shape[0], x2.shape[0]
            return n1, n2

        n1, n2 = get_n1_n2(x1, x2)

        (n1_batches, n1_batch_size, n2_batches, n2_batch_size) = \
            _get_n_batches_and_batch_sizes(n1, n2, batch_size, device_count)

        @utils.nt_tree_fn(nargs=1)
        def batch_input(x, batch_count, batch_size):
            input_shape = x.shape[1:]
            return np.reshape(x, (
                batch_count,
                batch_size,
            ) + input_shape)

        x1s = batch_input(x1, n1_batches, n1_batch_size)
        x2s = batch_input(x2, n2_batches, n2_batch_size)

        kwargs_np1 = {}
        kwargs_np2 = {}
        kwargs_other = {}

        for k, v in kwargs.items():
            if _is_np_ndarray(v):
                if k == 'rng':
                    key1, key2 = random.split(v)
                    v1 = random.split(key1, n1_batches)
                    v2 = random.split(key2, n2_batches)
                else:
                    assert isinstance(v, tuple) and len(v) == 2
                    v1 = np.reshape(v[0], (
                        n1_batches,
                        n1_batch_size,
                    ) + v[0].shape[1:])
                    v2 = np.reshape(v[1], (
                        n2_batches,
                        n2_batch_size,
                    ) + v[1].shape[1:])
                kwargs_np1[k] = v1
                kwargs_np2[k] = v2
            else:
                kwargs_other[k] = v

        def row_fn(_, x1):
            return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]

        def col_fn(x1, x2):
            x1, kwargs1 = x1
            x2, kwargs2 = x2
            kwargs_merge = {
                **kwargs_other,
                **dict((k, (kwargs1[k], kwargs2[k])) for k in kwargs1)
            }
            return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)

        _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
        return flatten(kernel, x2_is_none)
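
`serial_fn_x1` splits both inputs into batches and computes the kernel block by block: the outer `_scan` over `x1` batches runs, for each row, an inner `_scan` over `x2` batches, and the resulting blocks are later flattened back into a single kernel. A minimal sketch of the same blocked computation with `jax.lax.scan`; the toy kernel, batch sizes, and the final transpose/reshape are illustrative assumptions.

    import jax
    import jax.numpy as jnp

    def toy_kernel(x1, x2):
        return x1 @ x2.T                       # one (b1, b2) block of the kernel

    n1, n2, d = 8, 6, 3
    b1, b2 = 4, 3                              # batch sizes (must divide n1, n2)
    x1 = jnp.arange(n1 * d, dtype=jnp.float32).reshape(n1, d)
    x2 = jnp.arange(n2 * d, dtype=jnp.float32).reshape(n2, d)

    x1s = x1.reshape(n1 // b1, b1, d)          # row batches
    x2s = x2.reshape(n2 // b2, b2, d)          # column batches

    def col_fn(x1_batch, x2_batch):
        # Carry the current row batch; emit one kernel block per column batch.
        return x1_batch, toy_kernel(x1_batch, x2_batch)

    def row_fn(carry, x1_batch):
        _, row_blocks = jax.lax.scan(col_fn, x1_batch, x2s)
        return carry, row_blocks               # (n2 // b2, b1, b2)

    _, blocks = jax.lax.scan(row_fn, 0, x1s)   # (n1 // b1, n2 // b2, b1, b2)

    # Flatten the block structure back into a single (n1, n2) kernel.
    kernel = blocks.transpose(0, 2, 1, 3).reshape(n1, n2)
    assert jnp.allclose(kernel, toy_kernel(x1, x2))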