Example #1
0
  def serial_fn_kernel(k: Kernel, *args, **kwargs) -> Kernel:
    """Applies `kernel_fn` to `k` one (row-batch, column-batch) tile at a time.

    The first two axes of `k.nngp` give the sizes `n1` and `n2`. Tile sizes
    come from `_get_n_batches_and_batch_sizes`, using `batch_size` and
    `device_count` from the enclosing scope. Each tile is cut out with
    `k.slice` and evaluated as `kernel_fn(tile, *args, **kwargs)`. Tiles are
    visited with nested `_scan` calls: the outer scan runs over row starts and
    the inner scan over column starts. The stacked results are then passed to
    `flatten`, which presumably reassembles them into a single kernel
    (confirm against `flatten`).

    Args:
      k: the kernel to evaluate in batches.
      *args: positional arguments forwarded to every `kernel_fn` call.
      **kwargs: keyword arguments forwarded to every `kernel_fn` call.

    Returns:
      The result of `flatten` on the batched output. If `k.cov2` was `None`,
      the output's `cov2` is reset to `None` before flattening.
    """
    n1, n2 = k.nngp.shape[:2]
    # Only the batch sizes are needed here; the batch counts are unused.
    _, n1_batch_size, _, n2_batch_size = _get_n_batches_and_batch_sizes(
        n1, n2, batch_size, device_count)

    row_starts = np.arange(0, n1, n1_batch_size)
    col_starts = np.arange(0, n2, n2_batch_size)

    def tile_fn(row_start: int, col_start: int) -> Tuple[int, Kernel]:
      # NOTE(schsam): If we end up wanting to enable jit-of-batch then we will
      # probably have to change this to dynamic slicing.
      tile = k.slice(slice(row_start, row_start + n1_batch_size),
                     slice(col_start, col_start + n2_batch_size))
      # `row_start` is threaded through unchanged as the inner scan's carry.
      return row_start, kernel_fn(tile, *args, **kwargs)

    def row_fn(carry, row_start: int) -> Tuple[int, Kernel]:
      # Scan across every column batch for this row batch. Keep only the
      # stacked per-tile outputs, not the inner scan's final carry.
      row_tiles = _scan(tile_fn, row_start, col_starts)[1]
      return carry, row_tiles

    input_cov2_is_none = k.cov2 is None
    _, tiles = _scan(row_fn, 0, row_starts)
    if input_cov2_is_none:
      tiles = tiles.replace(cov2=None)
    return flatten(tiles, input_cov2_is_none)
Example #2
0
def _reshape_kernel_for_pmap(k: Kernel,
                             device_count: int,
                             n1_per_device: int) -> Kernel:
  """Gives the fields of `k` a new leading axis of size `device_count`.

  `nngp`, `ntk` and `cov1` have their leading axis split into
  `(device_count, n1_per_device)`, so each slice along the new axis holds
  `n1_per_device` rows. `cov2`, `mask2` and `x1_is_x2` are broadcast instead,
  so every slice along the new axis holds a full copy. A missing `cov2` is
  replaced by `cov1`, and a missing `mask2` by `mask1`, before broadcasting.

  Args:
    k: kernel to reshape. The leading axis of `nngp`, `ntk` and `cov1` must
      equal `device_count * n1_per_device`, or the reshape fails.
    device_count: size of the new leading axis.
    n1_per_device: number of rows per slice of the split fields.

  Returns:
    `k.replace(...)` with the fields above updated and the leading entry of
    `shape1` set to `n1_per_device`. Fields not listed (e.g. `mask1`) are not
    passed to `replace`.
  """
  def tile(x):
    # Prepend the device axis by broadcasting: every slice sees all of `x`.
    return np.broadcast_to(x, (device_count,) + x.shape)

  def split_rows(x):
    # Split axis 0 into (device_count, n1_per_device); keep trailing axes.
    return np.reshape(x, (device_count, n1_per_device) + x.shape[1:])

  cov2 = tile(k.cov1 if k.cov2 is None else k.cov2)

  # Falls back to `mask1`; stays `None` only when both masks are `None`.
  mask2 = k.mask1 if k.mask2 is None else k.mask2
  if mask2 is not None:
    mask2 = tile(mask2)

  x1_is_x2 = tile(k.x1_is_x2)

  nngp = split_rows(k.nngp)
  ntk = split_rows(k.ntk)
  cov1 = split_rows(k.cov1)

  new_shape1 = (n1_per_device,) + k.shape1[1:]
  return k.replace(nngp=nngp,
                   ntk=ntk,
                   cov1=cov1,
                   cov2=cov2,
                   x1_is_x2=x1_is_x2,
                   shape1=new_shape1,
                   mask2=mask2)
Example #3
0
    def serial_fn_kernel(k: Kernel, *args, **kwargs) -> Kernel:
        """Apply `kernel_fn` to `k` tile by tile, batching paired kwargs too.

        `n1` and `n2` are the first two axes of `k.nngp`. Tile sizes come from
        `_get_n_batches_and_batch_sizes`, using `batch_size` and
        `device_count` from the enclosing scope. Tiles are visited with nested
        `_scan` calls: the outer scan runs over row starts and the inner scan
        over column starts. Each tile is cut out with `k.slice` and evaluated
        with `kernel_fn`.

        Keyword arguments for which `_is_np_ndarray` is true must be 2-tuples
        `(v1, v2)`:
          * `v1` is reshaped to `(n1_batches, n1_batch_size, ...)` and scanned
            together with the row starts.
          * `v2` is reshaped to `(n2_batches, n2_batch_size, ...)` and scanned
            together with the column starts.
        This assumes `_scan` slices the leading axis of every leaf, like
        `jax.lax.scan` does (confirm). Each tile call then receives
        `(v1_batch, v2_batch)` under that key. All other keyword arguments are
        forwarded unchanged.

        Args:
          k: the kernel to evaluate in batches.
          *args: positional arguments forwarded to every `kernel_fn` call.
          **kwargs: keyword arguments, split as described above.

        Returns:
          The result of `flatten` on the stacked tile outputs. `flatten`
          presumably reassembles them into one kernel (confirm against its
          definition). If `k.cov2` was `None`, the output's `cov2` is reset to
          `None` first.

        Raises:
          ValueError: if a keyword argument accepted by `_is_np_ndarray` is not
            a 2-tuple.
        """
        n1, n2 = k.nngp.shape[:2]
        (n1_batches, n1_batch_size, n2_batches,
         n2_batch_size) = _get_n_batches_and_batch_sizes(
             n1, n2, batch_size, device_count)

        n1s = np.arange(0, n1, n1_batch_size)
        n2s = np.arange(0, n2, n2_batch_size)

        # Split kwargs into batched (row-side, column-side) halves and
        # pass-through values.
        kwargs_np1 = {}
        kwargs_np2 = {}
        kwargs_other = {}
        for key, v in kwargs.items():
            if not _is_np_ndarray(v):
                kwargs_other[key] = v
                continue
            # Explicit check instead of `assert`: `python -O` strips asserts,
            # which would let a malformed value be silently mis-indexed below.
            if not (isinstance(v, tuple) and len(v) == 2):
                raise ValueError(
                    'Batched keyword argument {!r} must be a 2-tuple, '
                    'got {}.'.format(key, type(v)))
            v1, v2 = v
            kwargs_np1[key] = np.reshape(
                v1, (n1_batches, n1_batch_size) + v1.shape[1:])
            kwargs_np2[key] = np.reshape(
                v2, (n2_batches, n2_batch_size) + v2.shape[1:])

        def row_fn(carry, row):
            # `row` is one (row start, kwargs_np1 batch) element. The inner
            # scan uses it as its carry and returns it unchanged.
            return carry, _scan(col_fn, row, (n2s, kwargs_np2))[1]

        def col_fn(row, col):
            n1_start, kwargs1 = row
            n2_start, kwargs2 = col
            kwargs_merge = {
                **kwargs_other,
                **{key: (kwargs1[key], kwargs2[key]) for key in kwargs1},
            }
            # NOTE(schsam): If we end up wanting to enable jit-of-batch then we will
            # probably have to change this to dynamic slicing.
            n1_slice = slice(n1_start, n1_start + n1_batch_size)
            n2_slice = slice(n2_start, n2_start + n2_batch_size)
            in_kernel = k.slice(n1_slice, n2_slice)
            return (n1_start, kwargs1), kernel_fn(in_kernel, *args, **kwargs_merge)

        cov2_is_none = k.cov2 is None
        _, k = _scan(row_fn, 0, (n1s, kwargs_np1))
        if cov2_is_none:
            k = k.replace(cov2=None)
        return flatten(k, cov2_is_none)
Example #4
0
def _set_cov2_is_none(k: Kernel) -> Kernel:
    """Return the kernel produced by `k.replace` with `cov2` set to ``None``.

    Despite the name, this does not test anything; it clears `cov2`.

    Args:
      k: kernel whose `cov2` field should be cleared.

    Returns:
      The result of `k.replace(cov2=None)`. Whether `k` itself is left
      unmodified depends on `Kernel.replace`; it presumably returns a new
      instance (confirm).
    """
    cleared = k.replace(cov2=None)
    return cleared