Example #1
0
def _reshape_kernel_for_pmap(k: Kernel, device_count: int,
                             n1_per_device: int) -> Kernel:
    """Give every array in `k` a new leading per-device axis.

    Presumably used to lay `k` out for `pmap` (per the function name).

    x1-side arrays (`nngp`, `ntk`, `cov1`) are split: their leading axis,
    which must have size `device_count * n1_per_device`, becomes
    `(device_count, n1_per_device)`.

    x2-side arrays (`cov2`, `mask2`, `x1_is_x2`) are replicated: the same
    values are broadcast along a new leading axis of size `device_count`.
    - If `k.cov2` is None, the original (unsplit) `k.cov1` is replicated in
      its place.
    - If `k.mask2` is None, `k.mask1` is used; `mask2` stays None only when
      both are None.

    `shape1` is updated to `(n1_per_device,) + k.shape1[1:]`. All other
    fields (including `mask1`) are carried over unchanged via `k.replace`.

    Args:
      k: kernel to re-lay out.
      device_count: size of the new leading (device) axis.
      n1_per_device: number of x1 rows assigned to each device.

    Returns:
      A copy of `k` with the fields above replaced.
    """
    def replicate(arr):
        # Same data on every device: prepend an axis of size `device_count`.
        return np.broadcast_to(arr, (device_count,) + arr.shape)

    def split(arr):
        # Cut the leading x1 axis into `device_count` chunks of equal size.
        return np.reshape(arr, (device_count, n1_per_device) + arr.shape[1:])

    # Must use the *original* cov1 as the fallback, not the split one below.
    cov2 = replicate(k.cov1 if k.cov2 is None else k.cov2)

    mask2 = k.mask1 if k.mask2 is None else k.mask2
    if mask2 is not None:
        mask2 = replicate(mask2)

    x1_is_x2 = replicate(k.x1_is_x2)

    nngp = split(k.nngp)
    ntk = split(k.ntk)
    cov1 = split(k.cov1)

    return k.replace(
        nngp=nngp,
        ntk=ntk,
        cov1=cov1,
        cov2=cov2,
        x1_is_x2=x1_is_x2,
        shape1=(n1_per_device,) + k.shape1[1:],
        mask2=mask2,
    )
Example #2
0
    def get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape):
        """Assemble the initial `ODEState` from user-supplied values.

        Reads `y_train`, `dtype`, `momentum` and `ODEState` from the
        enclosing scope (not visible here).

        Args:
          fx_train_or_state_0: either an `ODEState`, or the initial train
            outputs. If it is an `ODEState`, all four fields are taken from
            it and the `fx_test_0` argument is ignored. Initial train
            outputs may be None, meaning zeros shaped like `y_train`.
          fx_test_0: initial test outputs, or None. If given, it is
            broadcast to `fx_test_shape`.
          fx_test_shape: target shape for test-side arrays.

        Returns:
          `ODEState(fx_train, fx_test, qx_train, qx_test)`.
          - Train-side arrays are broadcast to `y_train.shape`.
          - Momentum arrays `qx_*` are filled in only when `momentum` is not
            None; missing ones become zeros.
          - `qx_test` is None whenever `fx_test` is None.

        Raises:
          ValueError: if `momentum is None` but momentum state was passed in.
        """
        if isinstance(fx_train_or_state_0, ODEState):
            state = fx_train_or_state_0
            fx_train_0, fx_test_0 = state.fx_train, state.fx_test
            qx_train_0, qx_test_0 = state.qx_train, state.qx_test
        else:
            fx_train_0, qx_train_0, qx_test_0 = fx_train_or_state_0, None, None

        fx_train_0 = (np.broadcast_to(fx_train_0, y_train.shape)
                      if fx_train_0 is not None
                      else np.zeros_like(y_train, dtype))

        if fx_test_0 is not None:
            fx_test_0 = np.broadcast_to(fx_test_0, fx_test_shape)

        if momentum is not None:
            if qx_train_0 is None:
                qx_train_0 = np.zeros_like(y_train, dtype)
            else:
                qx_train_0 = np.broadcast_to(qx_train_0, y_train.shape)

            # No test outputs means no test momentum either.
            if fx_test_0 is None:
                qx_test_0 = None
            elif qx_test_0 is None:
                qx_test_0 = np.zeros(fx_test_shape, dtype)
            else:
                qx_test_0 = np.broadcast_to(qx_test_0, fx_test_shape)
        elif qx_train_0 is not None or qx_test_0 is not None:
            raise ValueError('Got passed momentum state variables, while '
                             '`momentum is None`.')

        return ODEState(fx_train_0, fx_test_0, qx_train_0, qx_test_0)  # pytype: disable=wrong-arg-count
Example #3
0
 def broadcast(arg: np.ndarray) -> np.ndarray:
     """Replicate `arg` along a new leading axis of size `device_count`.

     `arg` is returned unchanged in two cases:
     - `device_count == 0`;
     - `arg` is already a `ShardedDeviceArray` whose leading axis equals
       `device_count`.

     `device_count` and `ShardedDeviceArray` come from the enclosing scope
     (not visible here).
     """
     # NOTE(review): device_count == 0 presumably means "no pmap"; confirm.
     # Short-circuit keeps the sharded check from running when it is 0.
     if device_count == 0 or (isinstance(arg, ShardedDeviceArray)
                              and arg.shape[0] == device_count):
         return arg
     leading_axis = (device_count,)
     return np.broadcast_to(arg, leading_axis + arg.shape)
 def broadcast(arg, count=2):
     """Stack `count` copies of `arg` along a new leading axis.

     Args:
       arg: an array, or any array-like value (Python scalar, nested
         sequence). The shape is read with `np.shape`, so non-array inputs
         work too; for arrays this is the same as `arg.shape`.
       count: size of the new leading axis. Defaults to 2, the value that
         was previously hard-coded, so existing callers are unaffected.

     Returns:
       `np.broadcast_to(arg, (count,) + np.shape(arg))`.
     """
     return np.broadcast_to(arg, (count,) + np.shape(arg))