def _reshape_kernel_for_pmap(k: Kernel, device_count: int,
                             n1_per_device: int) -> Kernel:
  """Lay out kernel `k` with a leading device axis for use under `pmap`.

  Per-`x1` arrays (`nngp`, `ntk`, `cov1`) have their first axis split into
  `(device_count, n1_per_device)`. `cov2`, `mask2` and `x1_is_x2` are instead
  repeated along a new leading axis of size `device_count`, so every device
  receives the same copy.

  Args:
    k: kernel to reshape. A missing (`None`) `cov2` falls back to `cov1`, and a
      missing `mask2` falls back to `mask1`. `nngp`, `ntk` and `cov1` must all
      be arrays: their `.shape` is read unconditionally.
    device_count: size of the new leading (device) axis.
    n1_per_device: number of `x1` rows given to each device. The leading axis
      of `nngp`, `ntk` and `cov1` must equal `device_count * n1_per_device`,
      or `np.reshape` raises.

  Returns:
    A copy of `k` (built with `k.replace`) holding the reshaped arrays and
    `shape1 = (n1_per_device,) + k.shape1[1:]`.
  """
  def replicate(x):
    # Prepend an axis of size `device_count`; each slice is the same data.
    return np.broadcast_to(x, (device_count,) + x.shape)

  def split_rows(x):
    # Break the leading x1 axis into `device_count` equal chunks.
    return np.reshape(x, (device_count, n1_per_device) + x.shape[1:])

  cov2 = replicate(k.cov1 if k.cov2 is None else k.cov2)

  # When `mask2` is absent, reuse `mask1` (which may itself be `None`).
  mask2 = k.mask1 if k.mask2 is None else k.mask2
  if mask2 is not None:
    mask2 = replicate(mask2)

  x1_is_x2 = replicate(k.x1_is_x2)
  nngp, ntk, cov1 = map(split_rows, (k.nngp, k.ntk, k.cov1))

  return k.replace(
      nngp=nngp,
      ntk=ntk,
      cov1=cov1,
      cov2=cov2,
      x1_is_x2=x1_is_x2,
      shape1=(n1_per_device,) + k.shape1[1:],
      mask2=mask2)
def get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape):
  """Assemble the initial `ODEState` from the user-supplied starting values.

  `y_train`, `dtype`, `momentum` and `ODEState` are read from the enclosing
  scope.

  Args:
    fx_train_or_state_0: either an `ODEState`, whose four fields become the
      starting values (overriding `fx_test_0`), or the initial train outputs,
      where `None` means zeros shaped like `y_train`.
    fx_test_0: initial test outputs or `None`. Ignored when
      `fx_train_or_state_0` is an `ODEState`.
    fx_test_shape: shape that test-side values are broadcast to.

  Returns:
    `ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)`:
    - Train-side values are broadcast to `y_train.shape`.
    - When `momentum is None`, both `qx_*` entries are `None`.
    - Otherwise a missing `qx_train_0` is zero-filled. `qx_test_0` is `None`
      when there are no test outputs, zero-filled when missing, and broadcast
      to `fx_test_shape` otherwise.

  Raises:
    ValueError: if `momentum is None` but `qx_train` / `qx_test` was supplied.
  """
  if isinstance(fx_train_or_state_0, ODEState):
    state = fx_train_or_state_0
    fx_train_0, fx_test_0 = state.fx_train, state.fx_test
    qx_train_0, qx_test_0 = state.qx_train, state.qx_test
  else:
    fx_train_0 = fx_train_or_state_0
    qx_train_0, qx_test_0 = None, None

  fx_train_0 = (np.zeros_like(y_train, dtype) if fx_train_0 is None
                else np.broadcast_to(fx_train_0, y_train.shape))
  if fx_test_0 is not None:
    fx_test_0 = np.broadcast_to(fx_test_0, fx_test_shape)

  if momentum is not None:
    qx_train_0 = (np.zeros_like(y_train, dtype) if qx_train_0 is None
                  else np.broadcast_to(qx_train_0, y_train.shape))
    # Test-side momentum only exists when there are test outputs to track.
    if fx_test_0 is None:
      qx_test_0 = None
    elif qx_test_0 is None:
      qx_test_0 = np.zeros(fx_test_shape, dtype)
    else:
      qx_test_0 = np.broadcast_to(qx_test_0, fx_test_shape)
  elif qx_train_0 is not None or qx_test_0 is not None:
    # Momentum state is rejected when no `momentum` is configured.
    raise ValueError('Got passed momentum state variables, while '
                     '`momentum is None`.')

  return ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)  # pytype: disable=wrong-arg-count
def broadcast(arg: np.ndarray) -> np.ndarray:
  """Give `arg` a leading axis of size `device_count` (read from enclosing scope).

  `arg` is returned unchanged in two cases:
  - `device_count == 0`;
  - `arg` is already a `ShardedDeviceArray` whose first dimension equals
    `device_count`, so it needs no broadcasting.
  """
  # Short-circuit keeps `arg.shape[0]` unevaluated when `device_count == 0`.
  passthrough = device_count == 0 or (
      isinstance(arg, ShardedDeviceArray) and arg.shape[0] == device_count)
  return arg if passthrough else np.broadcast_to(arg,
                                                 (device_count,) + arg.shape)
def broadcast(arg, *, count=2):
  """Stack `count` copies of `arg` along a new leading axis.

  Args:
    arg: an array or array-like (Python scalars and nested lists included).
    count: size of the new leading axis. Defaults to 2, the value this helper
      previously hard-coded. Keyword-only, so callers passing extra positional
      arguments are not silently reinterpreted.

  Returns:
    An array of shape `(count,) + np.shape(arg)` whose slices all equal `arg`.
  """
  # `np.shape` (unlike `arg.shape`) also works for scalars and plain lists.
  return np.broadcast_to(arg, (count,) + np.shape(arg))