def serial_fn_kernel(k: Kernel, *args, **kwargs) -> Kernel:
    """Evaluate `kernel_fn` on `k` tile by tile and flatten the result.

    The first two dimensions of `k.nngp` are split into tiles whose sizes
    come from `_get_n_batches_and_batch_sizes`. Every tile of `k` is passed
    to `kernel_fn` together with `*args` and `**kwargs`, using two nested
    `_scan` calls (outer over row offsets, inner over column offsets).

    `batch_size`, `device_count`, `kernel_fn`, `flatten` and `_scan` are
    resolved from the surrounding scope.

    Args:
      k: Input kernel; must expose `nngp`, `cov2`, `slice` and `replace`.
      *args: Extra positional arguments forwarded to `kernel_fn`.
      **kwargs: Extra keyword arguments forwarded to `kernel_fn`.

    Returns:
      Whatever `flatten` produces from the scanned kernel and the flag
      telling whether the input had `cov2 is None`.
    """
    total_rows, total_cols = k.nngp.shape[:2]
    # Only the batch sizes are needed here; the batch counts are ignored.
    _, row_step, _, col_step = _get_n_batches_and_batch_sizes(
        total_rows, total_cols, batch_size, device_count)

    row_starts = np.arange(0, total_rows, row_step)
    col_starts = np.arange(0, total_cols, col_step)

    def compute_tile(row_start: int, col_start: int) -> Tuple[int, Kernel]:
        # NOTE(schsam): If we end up wanting to enable jit-of-batch then we
        # will probably have to change this to dynamic slicing.
        tile = k.slice(
            slice(row_start, row_start + row_step),
            slice(col_start, col_start + col_step))
        # The row offset is handed back unchanged as the scan carry.
        return row_start, kernel_fn(tile, *args, **kwargs)

    def compute_row(carry, row_start: int) -> Tuple[int, Kernel]:
        # Inner scan over column offsets; only its second output is kept.
        _, row_of_tiles = _scan(compute_tile, row_start, col_starts)
        return carry, row_of_tiles

    # Recorded before scanning so it reflects the *input* kernel.
    had_no_cov2 = k.cov2 is None

    _, result = _scan(compute_row, 0, row_starts)
    if had_no_cov2:
        result = result.replace(cov2=None)
    return flatten(result, had_no_cov2)
def _reshape_kernel_for_pmap(k: Kernel,
                             device_count: int,
                             n1_per_device: int) -> Kernel:
    """Give the fields of `k` a leading axis of length `device_count`.

    `nngp`, `ntk` and `cov1` have their first axis reshaped into
    `(device_count, n1_per_device)`. `cov2`, `mask2` and `x1_is_x2` are
    broadcast so that a new leading axis of length `device_count` is
    prepended. When `k.cov2` is `None`, `k.cov1` is broadcast in its place;
    when `k.mask2` is `None`, `k.mask1` is used instead (and left as `None`
    if that is `None` too).

    Args:
      k: Kernel to reshape.
      device_count: Length of the new leading axis.
      n1_per_device: Length of the second axis after reshaping; also the
        new first entry of `shape1`.

    Returns:
      The result of `k.replace(...)` with the reshaped / broadcast fields.
    """

    def replicate(x):
        # Prepend a `device_count` axis without changing the trailing shape.
        return np.broadcast_to(x, (device_count,) + x.shape)

    def shard(x):
        # Split the leading axis into (device_count, n1_per_device).
        return np.reshape(x, (device_count, n1_per_device) + x.shape[1:])

    second_cov = k.cov2
    second_cov = replicate(k.cov1 if second_cov is None else second_cov)

    second_mask = k.mask2
    if second_mask is None:
        second_mask = k.mask1
    if second_mask is not None:
        second_mask = replicate(second_mask)

    same_inputs = replicate(k.x1_is_x2)

    sharded_nngp = shard(k.nngp)
    sharded_ntk = shard(k.ntk)
    sharded_cov1 = shard(k.cov1)

    return k.replace(
        nngp=sharded_nngp,
        ntk=sharded_ntk,
        cov1=sharded_cov1,
        cov2=second_cov,
        x1_is_x2=same_inputs,
        shape1=(n1_per_device,) + k.shape1[1:],
        mask2=second_mask)
def serial_fn_kernel(k: Kernel, *args, **kwargs) -> Kernel:
    """Evaluate `kernel_fn` on `k` tile by tile, batching array kwargs too.

    The first two dimensions of `k.nngp` are split into tiles whose counts
    and sizes come from `_get_n_batches_and_batch_sizes`. Keyword arguments
    for which `_is_np_ndarray` is true must be 2-tuples: the first entry is
    reshaped to `(n1_batches, n1_batch_size, ...)` and scanned alongside the
    row offsets, the second to `(n2_batches, n2_batch_size, ...)` and
    scanned alongside the column offsets. All other keyword arguments are
    forwarded to `kernel_fn` unchanged.

    `batch_size`, `device_count`, `kernel_fn`, `flatten` and `_scan` are
    resolved from the surrounding scope.

    Args:
      k: Input kernel; must expose `nngp`, `cov2`, `slice` and `replace`.
      *args: Extra positional arguments forwarded to `kernel_fn`.
      **kwargs: Extra keyword arguments, split as described above.

    Returns:
      Whatever `flatten` produces from the scanned kernel and the flag
      telling whether the input had `cov2 is None`.
    """
    total_rows, total_cols = k.nngp.shape[:2]
    (row_batches, row_step,
     col_batches, col_step) = _get_n_batches_and_batch_sizes(
         total_rows, total_cols, batch_size, device_count)

    row_starts = np.arange(0, total_rows, row_step)
    col_starts = np.arange(0, total_cols, col_step)

    row_kwargs = {}
    col_kwargs = {}
    static_kwargs = {}
    for name, value in kwargs.items():
        if not _is_np_ndarray(value):
            static_kwargs[name] = value
            continue
        assert isinstance(value, tuple) and len(value) == 2
        # Fold the leading axis of each entry into (batches, batch_size).
        per_row = np.reshape(
            value[0], (row_batches, row_step) + value[0].shape[1:])
        per_col = np.reshape(
            value[1], (col_batches, col_step) + value[1].shape[1:])
        row_kwargs[name] = per_row
        col_kwargs[name] = per_col

    def compute_tile(row_item, col_item):
        # NOTE(schsam): If we end up wanting to enable jit-of-batch then we
        # will probably have to change this to dynamic slicing.
        row_start, row_kw = row_item
        col_start, col_kw = col_item
        # Batched kwargs are re-paired as (row part, column part) per tile.
        merged = {
            **static_kwargs,
            **{name: (row_kw[name], col_kw[name]) for name in row_kw},
        }
        tile = k.slice(
            slice(row_start, row_start + row_step),
            slice(col_start, col_start + col_step))
        # The row item is handed back unchanged as the scan carry.
        return (row_start, row_kw), kernel_fn(tile, *args, **merged)

    def compute_row(carry, row_item):
        # Inner scan over column offsets; only its second output is kept.
        _, row_of_tiles = _scan(compute_tile, row_item,
                                (col_starts, col_kwargs))
        return carry, row_of_tiles

    # Recorded before scanning so it reflects the *input* kernel.
    had_no_cov2 = k.cov2 is None

    _, result = _scan(compute_row, 0, (row_starts, row_kwargs))
    if had_no_cov2:
        result = result.replace(cov2=None)
    return flatten(result, had_no_cov2)
def _set_cov2_is_none(k: Kernel) -> Kernel:
    """Return `k.replace(cov2=None)`, i.e. `k` with its `cov2` field cleared.

    Args:
      k: Kernel whose `cov2` should be set to `None`.

    Returns:
      The value produced by `k.replace(cov2=None)`.
    """
    without_cov2 = k.replace(cov2=None)
    return without_cov2