예제 #1
0
    def ker_fun(kernels):
        """Reduce the spatial dimensions of a `Kernel` away.

        Kernels whose `nngp` is already 2D (`[batch_size_1, batch_size_2]`)
        are returned unchanged. Otherwise `var1` (and `var2`, if present) are
        averaged over axes 1 and 2, and `nngp` (and `ntk`, if it is an array)
        are reduced to 2D:

        * a 4D `nngp` is averaged over axes 2 and 3;
        * a 6D `nngp` is reduced by tracing over the axis pairs (4, 5) and
          (2, 3) and dividing by the number of traced entries.

        Args:
          kernels: a `Kernel`, unpacked as its six fields.

        Returns:
          A `Kernel` with 2D `nngp` and `is_height_width=True`. If the input
          `nngp` was already 2D, `kernels` itself is returned.

        Raises:
          ValueError: if `nngp` is not 2D, 4D or 6D.
        """
        var1, nngp, var2, ntk, is_gaussian, _ = kernels
        if nngp.ndim == 2:
            return kernels

        # Validate before touching `var1`/`var2` so a bad rank is reported
        # with an accurate message instead of failing inside a reduction.
        if nngp.ndim not in (4, 6):
            raise ValueError(
                '`nngp` array must be 2d, 4d or 6d, got %d.' % nngp.ndim)

        var1 = np.mean(var1, axis=(1, 2))
        var2 = var2 if var2 is None else np.mean(var2, axis=(1, 2))

        if nngp.ndim == 4:
            nngp = np.mean(nngp, axis=(2, 3))
            if _is_array(ntk):
                ntk = np.mean(ntk, axis=(2, 3))
            return Kernel(var1, nngp, var2, ntk, is_gaussian, True)

        def trace(x):
            # Mean of the diagonal entries over both spatial axis pairs.
            count = x.shape[2] * x.shape[4]
            y = np.trace(x, axis1=4, axis2=5)
            z = np.trace(y, axis1=2, axis2=3)
            return z / count

        nngp = trace(nngp)
        if _is_array(ntk):
            ntk = trace(ntk)
        return Kernel(var1, nngp, var2, ntk, is_gaussian, True)
예제 #2
0
  def parallel_fn_kernel(kernel, *args, **kwargs):
    """Evaluates `kernel_fn` on `kernel` with its `x1` batch split by device.

    The leading (`n1`) axis of `nngp`, `ntk` and `cov1` is reshaped to
    `[_device_count, n1_per_device, ...]`, and `shape1` is updated to match.
    `cov2` (replaced by `cov1` when it is `None`) and `x1_is_x2` gain a new
    leading axis of size `_device_count` via broadcasting. When `n1` is smaller
    than `device_count`, only `n1` devices are used, one example each.

    Args:
      kernel: the input kernel (a namedtuple-like object with `_asdict`).
      *args: forwarded to `kernel_fn`.
      **kwargs: forwarded to `kernel_fn`.

    Returns:
      The result of `kernel_fn`, with `cov2` reset to `None` if it was `None`
      on input, passed through `_flatten_kernel(..., is_parallel=True)`.

    Raises:
      ValueError: if `n1 >= device_count` and `n1` is not a multiple of
        `device_count`.
    """
    n1 = kernel.cov1.shape[0]

    _device_count = device_count

    n1_per_device, ragged = divmod(n1, device_count)
    if n1_per_device and ragged:
      raise ValueError(
          ('Dataset size ({}) must be divisible by the number of '
           'physical devices ({}).').format(n1, device_count))
    elif not n1_per_device:
      # Fewer examples than devices: use one device per example.
      _device_count = ragged
      n1_per_device = 1

    kernel_dict = kernel._asdict()

    cov2 = kernel_dict['cov2']
    cov2_is_none = cov2 is None
    if cov2_is_none:
      # `cov2` is absent, so every device receives the full `cov1` instead.
      cov2 = kernel_dict['cov1']
    kernel_dict['cov2'] = np.broadcast_to(cov2, (_device_count,) + cov2.shape)
    kernel_dict['x1_is_x2'] = np.broadcast_to(
        kernel_dict['x1_is_x2'],
        (_device_count,) + kernel_dict['x1_is_x2'].shape)

    # Reassigning existing keys while iterating is safe: the dict's size does
    # not change.
    for k, v in kernel_dict.items():
      if k in ('nngp', 'ntk', 'cov1'):
        kernel_dict[k] = \
            np.reshape(v, (_device_count, n1_per_device,) + v.shape[1:])
      if k in ('shape1',):
        kernel_dict[k] = (n1_per_device,) + v[1:]
    kernel = kernel_fn(Kernel(**kernel_dict), *args, **kwargs)
    if cov2_is_none:
      kernel = kernel._replace(cov2=None)
    return _flatten_kernel(kernel, cov2_is_none, True)
예제 #3
0
  def serial_fn_kernel(k: Kernel, *args, **kwargs) -> Kernel:
    """Evaluates `kernel_fn` on `k` block by block and reassembles the result.

    The leading `[n1, n2]` axes of `k.nngp` are tiled into blocks of size
    `n1_batch_size x n2_batch_size`. A nested `_scan` over row and column
    offsets applies `kernel_fn` to each `k.slice(...)`. The stacked output
    then has `cov2` cleared (when the input had none) and is passed to
    `flatten`.
    """
    n1, n2 = k.nngp.shape[:2]
    _, n1_batch_size, _, n2_batch_size = _get_n_batches_and_batch_sizes(
        n1, n2, batch_size, device_count)

    row_starts = np.arange(0, n1, n1_batch_size)
    col_starts = np.arange(0, n2, n2_batch_size)

    def col_fn(row_start: int, col_start: int) -> Tuple[int, Kernel]:
      # NOTE(schsam): If we end up wanting to enable jit-of-batch then we will
      # probably have to change this to dynamic slicing.
      block = k.slice(slice(row_start, row_start + n1_batch_size),
                      slice(col_start, col_start + n2_batch_size))
      return row_start, kernel_fn(block, *args, **kwargs)

    def row_fn(carry, row_start: int) -> Tuple[int, Kernel]:
      # Keep the carry unchanged; emit the column-scanned results as output.
      return carry, _scan(col_fn, row_start, col_starts)[1]

    cov2_is_none = k.cov2 is None
    _, result = _scan(row_fn, 0, row_starts)
    if cov2_is_none:
      result = result.replace(cov2=None)
    return flatten(result, cov2_is_none)
예제 #4
0
def _reshape_kernel_for_pmap(k: Kernel,
                             device_count: int,
                             n1_per_device: int) -> Kernel:
  """Lays out a `Kernel` with a leading device axis.

  The first axis of `nngp`, `ntk` and `cov1` is split into
  `[device_count, n1_per_device]`. Three fields gain a new leading axis of
  size `device_count` by broadcasting:

  * `cov2`, falling back to `cov1` when `cov2` is `None`;
  * `mask2`, falling back to `mask1`, and left `None` if both are `None`;
  * `x1_is_x2`.

  `shape1` is updated to start with `n1_per_device`.
  """
  def add_device_axis(x):
    return np.broadcast_to(x, (device_count,) + x.shape)

  def split_batch_axis(x):
    return np.reshape(x, (device_count, n1_per_device) + x.shape[1:])

  cov2 = add_device_axis(k.cov1 if k.cov2 is None else k.cov2)

  mask2 = k.mask1 if k.mask2 is None else k.mask2
  if mask2 is not None:
    mask2 = add_device_axis(mask2)

  x1_is_x2 = add_device_axis(k.x1_is_x2)

  nngp = split_batch_axis(k.nngp)
  ntk = split_batch_axis(k.ntk)
  cov1 = split_batch_axis(k.cov1)

  return k.replace(
      nngp=nngp,
      ntk=ntk,
      cov1=cov1,
      cov2=cov2,
      x1_is_x2=x1_is_x2,
      shape1=(n1_per_device,) + k.shape1[1:],
      mask2=mask2)
예제 #5
0
def _transform_kernels_ab_relu(kernels, a, b, do_backprop, do_stabilize):
    """Compute new kernels after an `ABRelu` layer.

    See https://arxiv.org/pdf/1711.09090.pdf for the leaky ReLU derivation.

    Args:
      kernels: a `Kernel`, unpacked as its six fields.
      a: first layer parameter (presumably the slope for negative inputs).
      b: second layer parameter (presumably the slope for positive inputs).
      do_backprop: forwarded to `_arccos` and `_sqrt`.
      do_stabilize: if `True`, `nngp`, `var1` and `var2` are divided by
        `max(max|nngp|, 1e-12)` before the transformation and multiplied back
        afterwards.

    Returns:
      A new `Kernel` with `is_gaussian` set to `a == b`. The input arrays are
      not modified in place.
    """
    var1, nngp, var2, ntk, _, is_height_width = kernels

    if do_stabilize:
        # Rescale so the largest |nngp| entry is at most 1; undone below.
        factor = np.max([np.max(np.abs(nngp)), 1e-12])
        nngp = nngp / factor
        var1 = var1 / factor
        if var2 is not None:
            var2 = var2 / factor

    prod = _get_var_prod(var1, nngp, var2)
    cosines = nngp / np.sqrt(prod)
    angles = _arccos(cosines, do_backprop)
    dot_sigma = (a**2 + b**2 - (a - b)**2 * angles / np.pi) / 2
    if ntk is not None:
        ntk = ntk * dot_sigma

    nngp = ((a - b)**2 * _sqrt(prod - nngp**2, do_backprop) / (2 * np.pi) +
            dot_sigma * nngp)

    var1 = var1 * ((a**2 + b**2) / 2)
    if var2 is not None:
        var2 = var2 * ((a**2 + b**2) / 2)

    if do_stabilize:
        nngp = nngp * factor
        var1 = var1 * factor
        # `var2` is `None` when `x2 == x1`; multiplying it would raise.
        if var2 is not None:
            var2 = var2 * factor

    return Kernel(var1, nngp, var2, ntk, a == b, is_height_width)
예제 #6
0
    def ker_fun(kernels):
        """Sum the kernels of several input branches (`FanInSum`).

        All inputs must have `is_gaussian` set. Kernels are first brought to
        a common spatial-axes order, namely the majority order (ties go to
        height-width), by flipping the minority with `_flip_height_width`.
        Then each of the first four fields is summed across kernels. A field
        stays `None` only if it is `None` in every kernel.

        Args:
          kernels: a sequence of `Kernel`s. It is not modified.

        Returns:
          A `Kernel` with the summed fields, `is_gaussian=True`, and the
          common `is_height_width` order.

        Raises:
          NotImplementedError: if any input kernel is not `is_gaussian`.
        """
        is_gaussian = all(ker.is_gaussian for ker in kernels)
        if not is_gaussian:
            raise NotImplementedError(
                '`FanInSum` layer is only implemented for the '
                'case where all input layers are guaranteed to be mean'
                '-zero gaussian, i.e. having all `is_gaussian` '
                'set to `True`.')

        # Work on a copy so flipping never mutates the caller's sequence (and
        # so tuples are accepted too).
        kernels = list(kernels)
        n_kernels = len(kernels)
        n_height_width = sum(ker.is_height_width for ker in kernels)

        # If kernels have different height/width order, transpose the
        # minority to the majority order.
        is_height_width = bool(n_height_width >= n_kernels / 2)
        for i, ker in enumerate(kernels):
            if bool(ker.is_height_width) != is_height_width:
                kernels[i] = _flip_height_width(ker)

        kers = tuple(None if all(ker[i] is None for ker in kernels) else sum(
            ker[i] for ker in kernels) for i in range(4))
        return Kernel(*(kers + (is_gaussian, is_height_width)))
예제 #7
0
    def ker_fun(kernels):
        """Kernel transformation for average pooling of 6D covariances.

        `nngp` and `ntk` are passed through `_average_pool_nngp_6d` with the
        pooling window and strides. Both are reversed when the kernel is not
        stored in height-width order. When `var2` is `None`, `var1` is
        recomputed from the pooled `nngp` via `_diagonal_nngp_6d`. Otherwise a
        warning is issued and the variances are left as they were.
        """
        var1, nngp, var2, ntk, is_gaussian, is_height_width = kernels

        if is_height_width:
            pool_window, pool_strides = window_shape, strides
        else:
            pool_window, pool_strides = window_shape[::-1], strides[::-1]

        def pool(cov):
            return _average_pool_nngp_6d(cov, pool_window, pool_strides,
                                         padding)

        nngp = pool(nngp)
        ntk = pool(ntk)

        if var2 is not None:
            # TODO(romann)
            warnings.warn(
                'Pooling for different inputs `x1` and `x2` is not '
                'implemented and will only work if there are no '
                'nonlinearities in the network anywhere after the pooling '
                'layer. `var1` and `var2` will have wrong values. '
                'This will be fixed soon.')
        else:
            var1 = _diagonal_nngp_6d(nngp)

        return Kernel(var1, nngp, var2, ntk, is_gaussian, is_height_width)
예제 #8
0
def _move_kernel_to_cpu(kernel):
    """Fetch a kernel's data from an accelerator to the CPU.

    For a `Kernel`, `device_get` is applied to `var1`, `nngp`, `var2` and
    `ntk`, and the `is_gaussian` / `is_height_width` flags are copied as-is.
    Any other value is handed to `device_get` directly.
    """
    if not isinstance(kernel, Kernel):
        return device_get(kernel)

    var1, nngp, var2, ntk = (
        device_get(field)
        for field in (kernel.var1, kernel.nngp, kernel.var2, kernel.ntk))
    return Kernel(var1, nngp, var2, ntk,
                  kernel.is_gaussian, kernel.is_height_width)
예제 #9
0
def _flatten_kernel(k: Kernel, x2_is_none: bool, is_parallel: bool) -> Kernel:
    """Flatten the batch dimensions of a kernel array, namedtuple or `Kernel`.

    Anything with `_asdict` (a namedtuple) is rebuilt with `_replace`. A
    `Kernel` is rebuilt from its `asdict()`. In both cases the fields go
    through `_flatten_kernel_dict`. A bare `np.ndarray` is passed to
    `_flatten_batch_dimensions`.

    Raises:
      TypeError: for any other input type.
    """
    # pytype: disable=attribute-error
    if hasattr(k, '_asdict'):
        flat_fields = _flatten_kernel_dict(k._asdict(), x2_is_none,
                                           is_parallel)
        return k._replace(**flat_fields)

    if isinstance(k, Kernel):
        flat_fields = _flatten_kernel_dict(k.asdict(), x2_is_none,
                                           is_parallel)
        return Kernel(**flat_fields)
    # pytype: enable=attribute-error

    if isinstance(k, np.ndarray):
        return _flatten_batch_dimensions(k, is_parallel)

    raise TypeError(f'Expected kernel to be either a namedtuple, `Kernel`, or '
                    f'`np.ndarray`, got {type(k)}.')
예제 #10
0
 def _f(_x_or_kernel_np, *_args_np):
   """Reassemble the full inputs from their split parts and call `f`.

   When the input is a kernel, the array fields in `_x_or_kernel_np` are
   merged with `x_or_kernel_other` into a `Kernel`. Positional args are
   rebuilt from the array args (keyed by `args_np_idxs`) and `args_other`,
   and ordered by key. Entries in `args_other` win on key collisions.
   """
   if is_input_kernel:
     _x_or_kernel_np = Kernel(**{**_x_or_kernel_np, **x_or_kernel_other})

   # Keys are presumably positional indices; sorting restores arg order.
   merged_args = dict(zip(args_np_idxs, _args_np))
   merged_args.update(args_other)
   ordered_args = tuple(merged_args[idx] for idx in sorted(merged_args))
   return f(_x_or_kernel_np, *ordered_args, **kwargs)
예제 #11
0
    def ker_fun(kernels):
        """Compute the transformed kernels after a conv layer.

        Two layouts are supported, selected by `nngp.ndim`:

        * 4D (no pooling tracked): variances and covariances are convolved
          with `filter_shape` / `strides` directly.
        * 6D (pooling tracked): the filter shape and strides are reversed
          when the kernel is not in height-width order. Variances have their
          last two axes transposed after convolution. Covariances go through
          `_conv_nngp_6d_double_conv`, and the `is_height_width` flag is
          flipped.

        In both layouts, every result is then passed through
        `_affine(·, W_std, b_std)`.

        Args:
          kernels: a `Kernel`, unpacked as its six fields.

        Returns:
          A `Kernel` with `is_gaussian=True`.

        Raises:
          ValueError: if `nngp` is neither 4D nor 6D.
        """
        # var1: batch_1 * height * width
        # var2: batch_2 * height * width
        # nngp, ntk: batch_1 * batch_2 * height * height * width * width (pooling)
        #  or batch_1 * batch_2 * height * width (flattening)
        var1, nngp, var2, ntk, _, is_height_width = kernels

        # `conv_var` / `conv_nngp` are defined per layout so that the shared
        # code at the bottom of this function applies to both.
        if nngp.ndim == 4:

            def conv_var(x):
                x = _conv_var_3d(x, filter_shape, strides, padding)
                x = _affine(x, W_std, b_std)
                return x

            def conv_nngp(x):
                # Non-array values (e.g. a scalar placeholder `ntk`) skip the
                # convolution and only get the affine transform.
                if _is_array(x):
                    x = _conv_nngp_4d(x, filter_shape, strides, padding)
                x = _affine(x, W_std, b_std)
                return x

        elif nngp.ndim == 6:
            # Spatial axes may be stored width-first; reverse the filter and
            # strides to match the stored order.
            if not is_height_width:
                filter_shape_nngp = filter_shape[::-1]
                strides_nngp = strides[::-1]
            else:
                filter_shape_nngp = filter_shape
                strides_nngp = strides

            def conv_var(x):
                x = _conv_var_3d(x, filter_shape_nngp, strides_nngp, padding)
                if x is not None:
                    # Swap the two spatial axes to match the flipped order.
                    x = np.transpose(x, (0, 2, 1))
                x = _affine(x, W_std, b_std)
                return x

            def conv_nngp(x):
                if _is_array(x):
                    x = _conv_nngp_6d_double_conv(x, filter_shape_nngp,
                                                  strides_nngp, padding)
                x = _affine(x, W_std, b_std)
                return x

            # NOTE(review): presumably `_conv_nngp_6d_double_conv` returns the
            # spatial axes in swapped order, hence the flag flip; confirm.
            is_height_width = not is_height_width

        else:
            raise ValueError('`nngp` array must be either 4d or 6d, got %d.' %
                             nngp.ndim)

        var1 = conv_var(var1)
        var2 = conv_var(var2)
        nngp = conv_nngp(nngp)
        # NTK update: transformed previous NTK plus the new NNGP; the
        # `- b_std**2` presumably removes the bias term `_affine` adds to the
        # NTK (verify against `_affine`).
        ntk = conv_nngp(ntk) + nngp - b_std**2 if ntk is not None else ntk
        return Kernel(var1, nngp, var2, ntk, True, is_height_width)
예제 #12
0
    def new_ker_fun(x1_or_kernel,
                    x2=None,
                    compute_nngp=True,
                    compute_ntk=True):
        """Returns the `Kernel` resulting from applying `ker_fun` to given inputs.

        Inputs can be either a pair of `np.ndarray`s, a `Kernel`, or a `list`
        of `Kernel`s. Kernels are passed to `ker_fun` as-is. Arrays are first
        converted with `_inputs_to_kernel`.

        Args:
          x1_or_kernel: either a `np.ndarray` with shape
            `[batch_size_1] + input_shape`, a `Kernel`, or a `list` of
            `Kernel`s.
          x2: an optional `np.ndarray` with shape
            `[batch_size_2] + input_shape`. `None` means `x2 == x1` or that
            `x1_or_kernel` is a `Kernel`.
          compute_nngp: a boolean, `True` to compute the NNGP kernel.
          compute_ntk: a boolean, `True` to compute the NTK kernel.

        Returns:
          A `Kernel`. If neither kernel is requested, a `Kernel` with all
          fields set to `None` is returned.

        Raises:
          TypeError: if `x1_or_kernel` or `x2` has an unsupported type.
          ValueError: if `compute_ntk=True` but `compute_nngp=False`.
        """
        got_kernels = isinstance(x1_or_kernel, Kernel) or (
            isinstance(x1_or_kernel, list) and
            all(isinstance(k, Kernel) for k in x1_or_kernel))
        if got_kernels:
            return ker_fun(x1_or_kernel)

        if not isinstance(x1_or_kernel, np.ndarray):
            raise TypeError(
                'Inputs to a kernel propagation function should be '
                'a `Kernel`, '
                'a `list` of `Kernel`s, '
                'or a (tuple of) `np.ndarray`(s), got %s.' %
                type(x1_or_kernel))

        if not (x2 is None or isinstance(x2, np.ndarray)):
            raise TypeError('`x2` to a kernel propagation function '
                            'should be `None` or a `np.ndarray`, got %s.' %
                            type(x2))

        if not compute_nngp:
            if compute_ntk:
                raise ValueError(
                    'NNGP has to be computed to compute NTK. Please '
                    'set `compute_nngp=True`.')
            return Kernel(None, None, None, None, None, None)

        use_pooling = getattr(ker_fun, _USE_POOLING, True)
        kernel = _inputs_to_kernel(x1_or_kernel, x2, use_pooling, compute_ntk)
        return ker_fun(kernel)
예제 #13
0
    def ker_fun(kernels):
        """Compute the transformed kernels after a dense layer.

        Each of `var1`, `nngp`, `var2` and `ntk` goes through
        `_affine(·, W_std, b_std)`. When `ntk` is present, the transformed
        `nngp` minus `b_std**2` is then added to it.

        Returns:
          A `Kernel` with `is_gaussian=True` and `is_height_width=True`.
        """
        var1, nngp, var2, ntk, _, _ = kernels
        var1, nngp, var2, ntk = [
            _affine(entry, W_std, b_std) for entry in (var1, nngp, var2, ntk)
        ]
        if ntk is not None:
            ntk += nngp - b_std**2

        return Kernel(var1, nngp, var2, ntk, True, True)
예제 #14
0
    def serial_fn_kernel(k: Kernel, *args, **kwargs) -> Kernel:
        """Evaluates `kernel_fn` on `k` block by block, batching array kwargs.

        The leading `[n1, n2]` axes of `k.nngp` are tiled into
        `n1_batch_size x n2_batch_size` blocks. Keyword arguments flagged by
        `_is_np_ndarray` must be `(v1, v2)` pairs. `v1` is split into row
        batches and `v2` into column batches, and they are scanned alongside
        the block offsets. All other keyword arguments are forwarded to every
        call of `kernel_fn` unchanged. The stacked output then has `cov2`
        cleared (when the input had none) and is passed to `flatten`.
        """
        n1, n2 = k.nngp.shape[:2]
        (n1_batches, n1_batch_size, n2_batches,
         n2_batch_size) = _get_n_batches_and_batch_sizes(
             n1, n2, batch_size, device_count)

        row_starts = np.arange(0, n1, n1_batch_size)
        col_starts = np.arange(0, n2, n2_batch_size)

        row_kwargs, col_kwargs, static_kwargs = {}, {}, {}
        for name, value in kwargs.items():
            if not _is_np_ndarray(value):
                static_kwargs[name] = value
                continue
            assert isinstance(value, tuple) and len(value) == 2
            first, second = value[0], value[1]
            row_kwargs[name] = np.reshape(
                first, (n1_batches, n1_batch_size) + first.shape[1:])
            col_kwargs[name] = np.reshape(
                second, (n2_batches, n2_batch_size) + second.shape[1:])

        def col_fn(row_carry, col_item):
            # NOTE(schsam): If we end up wanting to enable jit-of-batch then we will
            # probably have to change this to dynamic slicing.
            row_start, row_kw = row_carry
            col_start, col_kw = col_item
            merged_kwargs = {
                **static_kwargs,
                **{key: (row_kw[key], col_kw[key]) for key in row_kw},
            }
            block = k.slice(slice(row_start, row_start + n1_batch_size),
                            slice(col_start, col_start + n2_batch_size))
            return (row_start, row_kw), kernel_fn(block, *args,
                                                  **merged_kwargs)

        def row_fn(carry, row_item):
            return carry, _scan(col_fn, row_item, (col_starts, col_kwargs))[1]

        cov2_is_none = k.cov2 is None
        _, result = _scan(row_fn, 0, (row_starts, row_kwargs))
        if cov2_is_none:
            result = result.replace(cov2=None)
        return flatten(result, cov2_is_none)
예제 #15
0
def _flatten_kernel(k, x2_is_none, store_on_device, is_parallel):
  """Flattens a kernel array or a `Kernel` along the batch dimension.

  Namedtuples (anything with `_asdict`) are rebuilt with `_replace`, and
  `Kernel`s from `asdict()`. In both cases the fields go through
  `_flatten_kernel_dict(..., x2_is_none, store_on_device, is_parallel)`.
  A bare `np.ndarray` is passed to `_flatten_batch_dimensions`.

  Raises:
    TypeError: for any other input type.
  """
  flatten_args = (x2_is_none, store_on_device, is_parallel)

  if hasattr(k, '_asdict'):
    return k._replace(**_flatten_kernel_dict(k._asdict(), *flatten_args))

  if isinstance(k, Kernel):
    return Kernel(**_flatten_kernel_dict(k.asdict(), *flatten_args))

  if isinstance(k, np.ndarray):
    # NOTE(review): unlike the two paths above, `is_parallel` is not
    # forwarded here; confirm against `_flatten_batch_dimensions`' signature.
    return _flatten_batch_dimensions(k)

  raise TypeError(
      ('Expected kernel to be either a namedtuple, Kernel, or `np.ndarray`, '
       'got %s.') % type(k))
예제 #16
0
def _flatten_kernel(k):
    """Flattens a kernel array or a `Kernel` along the batch dimension.

    For a `Kernel`, the fields are flattened as follows:

    * `nngp` and `ntk` go through `_flatten_batch_dimensions`;
    * `var1` and `var2` do too, with `discard_axis=1` and `discard_axis=0`
      respectively;
    * the boolean flags are collapsed with `np.all`, keeping `None` as
      `None`.

    A bare `np.ndarray` is flattened directly.

    Raises:
      TypeError: for any other input type.
    """
    def all_or_none(flag):
        return None if flag is None else np.all(flag)

    if isinstance(k, Kernel):
        return Kernel(
            _flatten_batch_dimensions(k.var1, discard_axis=1),
            _flatten_batch_dimensions(k.nngp),
            _flatten_batch_dimensions(k.var2, discard_axis=0),
            _flatten_batch_dimensions(k.ntk),
            all_or_none(k.is_gaussian),
            all_or_none(k.is_height_width))

    if isinstance(k, np.ndarray):
        return _flatten_batch_dimensions(k)

    raise TypeError(
        'Expected kernel to be either a `Kernel` or a `np.ndarray`, got %s.'
        % type(k))
예제 #17
0
def _transform_kernels_erf(kernels, do_backprop):
    """Compute new kernels after an `Erf` layer.

    Args:
      kernels: a `Kernel`, unpacked as its six fields.
      do_backprop: forwarded to `_arcsin` for the `nngp` update.

    Returns:
      A `Kernel` with transformed `var1`, `nngp`, `var2` and `ntk`, and
      `is_gaussian=False`.
    """
    var1, nngp, var2, ntk, _, is_height_width = kernels

    denom1 = 1 + 2 * var1
    denom2 = None if var2 is None else 1 + 2 * var2
    prod = _get_var_prod(denom1, nngp, denom2)

    dot_sigma = 4 / (np.pi * np.sqrt(prod - 4 * nngp**2))
    if ntk is not None:
        ntk *= dot_sigma

    nngp = _arcsin(2 * nngp / np.sqrt(prod), do_backprop) * 2 / np.pi

    def transform_var(var, denom):
        return np.arcsin(2 * var / denom) * 2 / np.pi

    var1 = transform_var(var1, denom1)
    var2 = None if var2 is None else transform_var(var2, denom2)

    return Kernel(var1, nngp, var2, ntk, False, is_height_width)
예제 #18
0
def _flip_height_width(kernels):
    """Flips the order of spatial axes in the covariance matrices.

    Args:
      kernels: a `Kernel` object.

    Returns:
      A `Kernel` object with `height` and `width` axes order flipped in
      all covariance matrices, and `is_height_width` negated. For example, if
      `kernels.nngp` has shape
      `[batch_size_1, batch_size_2, height, height, width, width]`, then
      `_flip_height_width(kernels).nngp` has shape
      `[batch_size_1, batch_size_2, width, width, height, height]`.
    """
    var1, nngp, var2, ntk, is_gaussian, is_height_width = kernels

    variance_axes = (0, 2, 1)
    covariance_axes = (0, 1, 4, 5, 2, 3)

    var1 = np.transpose(var1, variance_axes)
    if var2 is not None:
        var2 = np.transpose(var2, variance_axes)
    nngp = np.transpose(nngp, covariance_axes)
    if ntk is not None:
        ntk = np.transpose(ntk, covariance_axes)

    return Kernel(var1, nngp, var2, ntk, is_gaussian, not is_height_width)
예제 #19
0
    def ker_fun(kernels):
        """Global average pooling over all spatial axes of `nngp` and `ntk`.

        Every axis after the two batch axes is averaged out (for `ntk` only
        when it is an array). When `var2` is `None`, `var1` is rebuilt as the
        diagonal of the pooled `nngp`. Otherwise a warning is issued and the
        variances are left unchanged.

        Returns:
          A `Kernel` with `is_height_width=True`.
        """
        var1, nngp, var2, ntk, is_gaussian, _ = kernels

        spatial_axes = tuple(range(2, nngp.ndim))
        nngp = np.mean(nngp, axis=spatial_axes)
        if _is_array(ntk):
            ntk = np.mean(ntk, axis=spatial_axes)

        if var2 is not None:
            # TODO(romann)
            warnings.warn(
                'Pooling for different inputs `x1` and `x2` is not '
                'implemented and will only work if there are no '
                'nonlinearities in the network anywhere after the pooling '
                'layer. `var1` and `var2` will have wrong values. '
                'This will be fixed soon.')
        else:
            var1 = np.diagonal(nngp)

        return Kernel(var1, nngp, var2, ntk, is_gaussian, True)
예제 #20
0
    def get_sampled_kernel(x1, x2, key, n_samples):
        """Estimate a kernel by averaging single-sample kernels over RNG keys.

        Args:
          x1: first batch of inputs.
          x2: optional second batch. When given, it must match `x1` in every
            dimension after the first (checked with `assert`).
          key: either a single key of shape `(2,)`, which is split into
            `n_samples` subkeys, or a batch of keys that is iterated directly.
          n_samples: number of samples when `key` is a single key. Must be
            `None` when a batch of keys is passed.

        Returns:
          The sum of `ker_fun_sample_once(x1, x2, subkey)` over all subkeys,
          divided by the number of subkeys.

        Raises:
          ValueError: if a batch of keys is passed together with `n_samples`.
        """
        if x2 is not None:
            assert x1.shape[1:] == x2.shape[1:]

        if key.shape == (2, ):
            keys = random.split(key, n_samples)
        elif n_samples is None:
            keys = key
        else:
            raise ValueError('Got set `n_samples=%d` and %d RNG keys.' %
                             (n_samples, key.shape[0]))

        accumulated = Kernel(var1=None,
                             nngp=0. if compute_nngp else None,
                             var2=None,
                             ntk=0. if compute_ntk else None,
                             is_gaussian=None,
                             is_height_width=None)
        for subkey in keys:
            accumulated += ker_fun_sample_once(x1, x2, subkey)

        return accumulated / len(keys)
예제 #21
0
 def ker_fun(x1, x2, params):
     """Evaluate `nngp_fun` and `ntk_fun` and pack the results in a `Kernel`.

     The NNGP result is the second positional field and the NTK result the
     fourth. All other positional fields are `None`.
     """
     nngp = nngp_fun(x1, x2, params)
     ntk = ntk_fun(x1, x2, params)
     return Kernel(None, nngp, None, ntk, None, None, None, None)
예제 #22
0
def _inputs_to_kernel(x1, x2, use_pooling, compute_ntk):
    """Transforms (batches of) inputs to a `Kernel`.

    This is a private method. Docstring and example are for internal reference.

    The kernel contains the empirical covariances between different inputs and
    their entries (pixels) necessary to compute the covariance of the Gaussian
    Process corresponding to an infinite Bayesian or gradient-flow-trained
    neural network.

    The smallest necessary number of covariance entries is tracked. For
    example, all networks are assumed to have i.i.d. weights along the
    channel / feature / logits dimensions, hence covariance between different
    entries along these dimensions is known to be 0 and is not tracked.

    Args:
      x1: a 2D `np.ndarray` of shape `[batch_size_1, n_features]` (dense
        network) or 4D of shape `[batch_size_1, height, width, channels]`
        (conv-nets).
      x2: an optional `np.ndarray` with the same shape as `x1` apart
        from possibly different leading batch size. `None` means
        `x2 == x1`.
      use_pooling: a boolean, indicating whether pooling will be used somewhere
        in the model. If so, more covariance entries need to be tracked. Is set
        automatically based on the network topology. Specifically, is set to
        `False` if a `serial` or `parallel` networks contain a `Flatten` layer
        and no pooling layers (`AvgPool` or `GlobalAvgPool`). Has no effect for
        non-convolutional models.
      compute_ntk: a boolean, `True` to compute both NTK and NNGP kernels,
        `False` to only compute NNGP. When `True`, `ntk` is initialized to the
        scalar placeholder `0.`. When `False`, it is `None`.

    Returns:
      a `Kernel` object with `is_gaussian=False` and `is_height_width=True`.

    Raises:
      ValueError: if `x2` differs from `x1` in its non-batch shape, or if `x1`
        is neither 2D nor 4D.

    Example:
      ```python
          >>> x = np.ones((10, 32, 16, 3))
          >>> _inputs_to_kernel(x, None, use_pooling=True,
          ...                   compute_ntk=True).nngp.shape
          (10, 10, 32, 32, 16, 16)
          >>> _inputs_to_kernel(x, None, use_pooling=False,
          ...                   compute_ntk=True).nngp.shape
          (10, 10, 32, 16)
          >>> _inputs_to_kernel(x, None, use_pooling=True,
          ...                   compute_ntk=True).ntk
          0.0
          >>> x1 = np.ones((10, 128))
          >>> x2 = np.ones((20, 128))
          >>> _inputs_to_kernel(x1, x2, use_pooling=True,
          ...                   compute_ntk=False).nngp.shape
          (10, 20)
          >>> _inputs_to_kernel(x1, x2, use_pooling=False,
          ...                   compute_ntk=False).nngp.shape
          (10, 20)
          >>> _inputs_to_kernel(x1, x2, use_pooling=False,
          ...                   compute_ntk=False).ntk is None
          True
      ```
    """
    dtype = xla_bridge.canonicalize_dtype(np.float64)
    x1 = x1.astype(dtype)
    var1 = _get_variance(x1)

    if x2 is None:
        x2 = x1
        var2 = None
    else:
        if x1.shape[1:] != x2.shape[1:]:
            raise ValueError(
                '`x1` and `x2` are expected to be batches of'
                ' inputs with the same shape (apart from the batch size),'
                ' got %s and %s.' % (str(x1.shape), str(x2.shape)))

        x2 = x2.astype(dtype)
        var2 = _get_variance(x2)

    if use_pooling and x1.ndim == 4:
        # `np.dot` contracts the last axis of `x1` with the second-to-last
        # axis of `x2`. The trailing singleton makes that the channel axis,
        # giving `[b1, h, w, b2, h, w, 1]`. After squeezing, the transpose
        # reorders this to `[b1, b2, h, h, w, w]`.
        x2 = np.expand_dims(x2, -1)
        nngp = np.dot(x1, x2) / x1.shape[-1]
        nngp = np.transpose(np.squeeze(nngp, -1), (0, 3, 1, 4, 2, 5))

    elif x1.ndim == 4 or x1.ndim == 2:
        nngp = _batch_uncentered_covariance(x1, x2)

    else:
        raise ValueError('Inputs must be 2D or 4D `np.ndarray`s of shape '
                         '`[batch_size, n_features]` or '
                         '`[batch_size, height, width, channels]`, '
                         'got %s.' % str(x1.shape))

    # `0.` is a scalar placeholder that later layers transform into the NTK.
    ntk = 0. if compute_ntk else None
    is_gaussian = False
    is_height_width = True
    return Kernel(var1, nngp, var2, ntk, is_gaussian, is_height_width)
예제 #23
0
def _set_cov2_is_none(k: Kernel) -> Kernel:
    """Return `k.replace(cov2=None)`, i.e. `k` with its `cov2` field cleared."""
    cleared = {'cov2': None}
    return k.replace(**cleared)