Example #1
  def test_nested_parallel(self, same_inputs, kernel_type):
    platform = default_backend()
    rtol = RTOL if platform != 'tpu' else 0.05

    rng = random.PRNGKey(0)
    (input_key1,
     input_key2,
     input_key3,
     input_key4,
     mask_key,
     mc_key) = random.split(rng, 6)

    x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
    x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
    x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
    x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

    m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

    x1_1 = test_utils.mask(
        x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
    x1_2 = test_utils.mask(
        x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
    if not same_inputs:
      x2_3 = test_utils.mask(
          x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
      x2_4 = test_utils.mask(
          x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

    x1 = (((x1_1, x1_2), x1_3), x1_4)
    x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

    N_in = 2 ** 7

    # Include dropout only on non-TPU backends, since a large N is needed for
    # the Monte Carlo estimate to converge on TPU.
    dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

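    # The nesting of `stax.parallel` branches below mirrors the nested tuple
    # structure of the inputs x1 / x2.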
    init_fn, apply_fn, kernel_fn = stax.parallel(
        stax.parallel(
            stax.parallel(stax.Dense(N_in),
                          stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                      stax.Flatten())),
            stax.serial(stax.Conv(N_in + 2, (2, 2)),
                        dropout_or_id,
                        stax.GlobalAvgPool())),
        stax.Conv(N_in + 3, (2,)))

    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, mc_key, N_SAMPLES, implementation=2,
        vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                   if platform == 'tpu' else None)
    )

    test_utils.assert_close_matrices(
        self,
        kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
        kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
        rtol)
Example #2
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, layer_norm,
             parameterization, use_dropout):
    fc = partial(stax.Dense,
                 W_std=W_std,
                 b_std=b_std,
                 parameterization=parameterization)
    conv = partial(stax.Conv,
                   filter_shape=filter_shape,
                   strides=strides,
                   padding=padding,
                   W_std=W_std,
                   b_std=b_std,
                   parameterization=parameterization)
    affine = conv(width) if is_conv else fc(width)
    rate = np.onp.random.uniform(0.5, 0.9)
    dropout = stax.Dropout(rate, mode='train')
    ave_pool = stax.AvgPool((2, 3), None,
                            'SAME' if padding == 'SAME' else 'CIRCULAR')
    ave_pool_or_identity = ave_pool if use_pooling else stax.Identity()
    dropout_or_identity = dropout if use_dropout else stax.Identity()
    layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                              stax.LayerNorm(axis=layer_norm))
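    # res_unit is the non-identity branch of the residual block when `is_res`,
    # and is applied sequentially otherwise.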
    res_unit = stax.serial(ave_pool_or_identity, phi, dropout_or_identity,
                           affine)
    if is_res:
        block = stax.serial(affine, stax.FanOut(2),
                            stax.parallel(stax.Identity(), res_unit),
                            stax.FanInSum(), layer_norm_or_identity)
    else:
        block = stax.serial(affine, res_unit, layer_norm_or_identity)

    if proj_into_2d == 'FLAT':
        proj_layer = stax.Flatten()
    elif proj_into_2d == 'POOL':
        proj_layer = stax.GlobalAvgPool()
    elif proj_into_2d.startswith('ATTN'):
        n_heads = int(np.sqrt(width))
        n_chan_val = int(np.round(float(width) / n_heads))
        fixed = proj_into_2d == 'ATTN_FIXED'
        proj_layer = stax.serial(
            stax.GlobalSelfAttention(width,
                                     n_chan_key=width,
                                     n_chan_val=n_chan_val,
                                     n_heads=n_heads,
                                     fixed=fixed,
                                     W_key_std=W_std,
                                     W_value_std=W_std,
                                     W_query_std=W_std,
                                     W_out_std=1.0,
                                     b_std=b_std), stax.Flatten())
    else:
        raise ValueError(proj_into_2d)
    readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

    return stax.serial(block, readout)
Example #3
    def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant,
                       concat, proj, p, n, transpose):
        if isinstance(concat, int) and concat > n:
            raise absltest.SkipTest('Concatenation axis out of bounds.')

        test_utils.skip_test(self)
        if default_backend() == 'gpu' and n > 3:
            raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

        width = 256
        n_samples = 256
        tol = 0.03
        key = random.PRNGKey(1)

        spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
        filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
        strides = (2, 1, 3, 2, 3)[:n]
        spatial_spec = 'HWDZX'[:n]
        dimension_numbers = ('N' + spatial_spec + 'C', 'OI' + spatial_spec,
                             'N' + spatial_spec + 'C')

        x1 = np.cos(random.normal(key, (2, ) + spatial_shape + (2, )))
        x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

        if same_inputs:
            x2 = None
        else:
            x2 = np.cos(random.normal(key, (4, ) + spatial_shape + (2, )))
            x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

        def get_attn():
            return stax.GlobalSelfAttention(
                n_chan_out=width,
                n_chan_key=width,
                n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
                n_heads=int(np.sqrt(width)),
            ) if proj == 'avg' else stax.Identity()

        conv = stax.ConvTranspose if transpose else stax.Conv

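        # Fan the input out into three convolutional branches, then merge them
        # by sum or concatenation below.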
        nn = stax.serial(
            stax.FanOut(3),
            stax.parallel(
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.5,
                         b_std=0.2),
                    stax.LayerNorm(axis=(1, -1)),
                    stax.Abs(),
                    stax.DotGeneral(rhs=0.9),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.2,
                         b_std=0.1),
                ),
                stax.serial(
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='SAME',
                         W_std=0.1,
                         b_std=0.3),
                    stax.Relu(),
                    stax.Dropout(0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=0.9,
                         b_std=1.),
                ),
                stax.serial(
                    get_attn(),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='CIRCULAR',
                         W_std=1.,
                         b_std=0.1),
                    stax.Erf(),
                    stax.Dropout(0.2),
                    stax.DotGeneral(rhs=0.7),
                    conv(dimension_numbers=dimension_numbers,
                         out_chan=width,
                         strides=strides,
                         filter_shape=filter_shape,
                         padding='VALID',
                         W_std=1.,
                         b_std=0.1),
                )),
            (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
            get_attn(),
            {
                'avg': stax.GlobalAvgPool(),
                'sum': stax.GlobalSumPool(),
                'flatten': stax.Flatten(),
            }[proj],
        )

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(width, 1., 0.))
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(nn,
                                                       stax.Dense(1, 1., 0.))
        else:
            raise ValueError(get)

        kernel_fn_mc = nt.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if concat in (0, -n) else -1,
            implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
            vmap_axes=None if concat in (0, -n) else 0,
        )

        kernel_fn = jit(kernel_fn, static_argnames='get')
        exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
        empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #4
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm,
             parameterization, s, use_dropout):

  if is_conv:
    # Select a random filter order.
    default_filter_spec = 'HW'
    filter_specs = [''.join(p) for p in itertools.permutations('HWIO')]
    filter_spec = prandom.choice(filter_specs)
    filter_shape = tuple(filter_shape[default_filter_spec.index(c)]
                         for c in filter_spec if c in default_filter_spec)
    strides = tuple(strides[default_filter_spec.index(c)]
                    for c in filter_spec if c in default_filter_spec)

    # Select the activation order.
    default_spec = 'NHWC'
    if default_backend() == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['N' + ''.join(p) for p in itertools.permutations('CHW')]
    else:
      specs = [''.join(p) for p in itertools.permutations('NCHW')]
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)

  else:
    input_shape = (INPUT_SHAPE[0], onp.prod(INPUT_SHAPE[1:]))
    if default_backend() == 'tpu':
      spec = 'NC'
    else:
      spec = prandom.choice(['NC', 'CN'])
      if spec.index('N') == 1:
        input_shape = input_shape[::-1]

    filter_spec = None

  dimension_numbers = (spec, filter_spec, spec)
  batch_axis, channel_axis = spec.index('N'), spec.index('C')

  spec_fc = ''.join(c for c in spec if c in ('N', 'C'))
  batch_axis_fc, channel_axis_fc = spec_fc.index('N'), spec_fc.index('C')

  if not is_conv:
    batch_axis = batch_axis_fc
    channel_axis = channel_axis_fc

  if layer_norm:
    layer_norm = tuple(spec.index(c) for c in layer_norm)

  def fc(out_dim, s):
    return stax.Dense(
        out_dim=out_dim,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization,
        s=s,
        batch_axis=batch_axis_fc,
        channel_axis=channel_axis_fc
    )

  def conv(out_chan, s):
    return stax.Conv(
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        dimension_numbers=dimension_numbers,
        parameterization=parameterization,
        s=s
    )

  affine = conv(width, (s, s)) if is_conv else fc(width, (s, s))
  affine_bottom = conv(width, (1, s)) if is_conv else fc(width, (1, s))

  rate = onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    global_pool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    global_pool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               batch_axis=batch_axis,
                               channel_axis=channel_axis)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None else
                            stax.LayerNorm(axis=layer_norm,
                                           batch_axis=batch_axis,
                                           channel_axis=channel_axis))
  res_unit = stax.serial(dropout_or_identity, affine, pool_or_identity)
  if is_res:
    block = stax.serial(
        affine_bottom,
        stax.FanOut(2),
        stax.parallel(stax.Identity(),
                      res_unit),
        stax.FanInSum(),
        layer_norm_or_identity,
        phi)
  else:
    block = stax.serial(
        affine_bottom,
        res_unit,
        layer_norm_or_identity,
        phi)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(batch_axis, batch_axis_fc)
  elif proj_into_2d == 'POOL':
    proj_layer = global_pool_fn(batch_axis, channel_axis)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            linear_scaling=True,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            batch_axis=batch_axis,
            channel_axis=channel_axis),
        stax.Flatten(batch_axis, batch_axis_fc))
  else:
    raise ValueError(proj_into_2d)

  readout = stax.serial(proj_layer,
                        fc(1 if is_ntk else width, (s, 1 if is_ntk else s)))

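  # Parallelize over devices only when the batch dimension `N` is leading.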
  device_count = -1 if spec.index('N') == 0 else 0

  net = stax.serial(block, readout)
  return net, input_shape, device_count, channel_axis_fc
Example #5
def bann_model(
    W_std,
    b_std,
    first_layer_width,
    second_layer_width,
    subNN_num,
    keep_rate,
    activation,
    parameterization
):
    """Construct fully connected NN model and infinite width NTK & NNGP kernel
       function.

    Args:
        W_std (float): Weight standard deviation.
        b_std (float): Bias standard deviation.
        first_layer_width (int): First Hidden layer width.
        second_layer_width (int): Second Hidden layer width.
        subNN_num (int) : Number of sub neural networks in the architecture
        keep_rate (float): 1 - Dropout rate.
        activation (string): Activation function string, 'erf' or 'relu'.
        parameterization (string): Parameterization string, 'ntk' or 'standard'.

    Returns:
        `(init_fn, apply_fn, kernel_fn)`
    """
    act = activation_fn(activation)

    # multi-task learning
    # Computational Skeleton Block
    CSB = stax.serial(
        stax.FanOut(subNN_num),
        # One branch per sub-network; the i-th branch widens its second hidden
        # layer by a factor of i.
        stax.parallel(*[
            stax.serial(
                Dense(first_layer_width, W_std, b_std,
                      parameterization=parameterization), act(),
                Dense(i * second_layer_width, W_std, b_std,
                      parameterization=parameterization), act(),
                stax.Dropout(keep_rate))
            for i in range(1, subNN_num + 1)
        ]),
        stax.FanInConcat()
    )

    Additive = stax.serial(
        stax.FanOut(2),
        stax.parallel(
            stax.serial(
                CSB,
                stax.Dropout(keep_rate)
            ),
            stax.serial(
                CSB,
                stax.Dropout(keep_rate)
            )
        ),
        stax.FanInConcat()
    )

    init_fn, apply_fn, kernel_fn = stax.serial(
        Additive,
        Dense(1, W_std, b_std, parameterization=parameterization)
    )

    apply_fn = jit(apply_fn)

    return init_fn, apply_fn, kernel_fn
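
# A minimal usage sketch (not part of the original example). It assumes the
# snippet's own helpers (`bann_model`, `Dense`, `activation_fn`) are defined as
# above and that `neural_tangents` / `jax` are installed; all shapes and
# hyperparameters below are illustrative.
from jax import random
import neural_tangents as nt

init_fn, apply_fn, kernel_fn = bann_model(
    W_std=1.5, b_std=0.05, first_layer_width=512, second_layer_width=64,
    subNN_num=10, keep_rate=0.8, activation='relu', parameterization='ntk')

key_x, key_y, key_t = random.split(random.PRNGKey(0), 3)
x_train = random.normal(key_x, (20, 10))  # 20 points, 10 features (illustrative)
y_train = random.normal(key_y, (20, 1))   # scalar targets (illustrative)
x_test = random.normal(key_t, (5, 10))

# Mean test predictions of the infinitely wide ensemble trained to convergence
# under the NTK.
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
y_test_ntk = predict_fn(x_test=x_test, get='ntk', compute_cov=False)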
Example #6
def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
             phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm,
             parameterization, use_dropout):

  if is_conv:
    # Select a random dimension order.
    default_spec = 'NHWC'
    if xla_bridge.get_backend().platform == 'tpu':
      # Keep batch dimension leading for TPU for batching to work.
      specs = ['NHWC', 'NHCW', 'NCHW']
    else:
      specs = ['NHWC', 'NHCW', 'NCHW', 'CHWN', 'CHNW', 'CNHW']
    spec = prandom.choice(specs)
    input_shape = tuple(INPUT_SHAPE[default_spec.index(c)] for c in spec)

    if layer_norm:
      layer_norm = tuple(spec.index(c) for c in layer_norm)

  else:
    # Only `NC` dimension order is supported and is enforced by layers.
    spec = None
    input_shape = INPUT_SHAPE
    if layer_norm:
      layer_norm = prandom.choice([(1,), (-1,)])

  dimension_numbers = (spec, 'HWIO', spec)

  fc = partial(
      stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization)

  def conv(out_chan):
    return stax.GeneralConv(
        dimension_numbers=dimension_numbers,
        out_chan=out_chan,
        filter_shape=filter_shape,
        strides=strides,
        padding=padding,
        W_std=W_std,
        b_std=b_std,
        parameterization=parameterization)
  affine = conv(width) if is_conv else fc(width)

  spec = dimension_numbers[-1]

  rate = np.onp.random.uniform(0.5, 0.9)
  dropout = stax.Dropout(rate, mode='train')

  if pool_type == 'AVG':
    pool_fn = stax.AvgPool
    globalPool_fn = stax.GlobalAvgPool
  elif pool_type == 'SUM':
    pool_fn = stax.SumPool
    globalPool_fn = stax.GlobalSumPool
  else:
    raise ValueError(pool_type)

  if use_pooling:
    pool_or_identity = pool_fn((2, 3),
                               None,
                               'SAME' if padding == 'SAME' else 'CIRCULAR',
                               spec=spec)
  else:
    pool_or_identity = stax.Identity()
  dropout_or_identity = dropout if use_dropout else stax.Identity()
  layer_norm_or_identity = (stax.Identity() if layer_norm is None
                            else stax.LayerNorm(axis=layer_norm, spec=spec))
  res_unit = stax.serial(pool_or_identity, phi, dropout_or_identity, affine)
  if is_res:
    block = stax.serial(
        affine,
        stax.FanOut(2),
        stax.parallel(stax.Identity(),
                      res_unit),
        stax.FanInSum(),
        layer_norm_or_identity)
  else:
    block = stax.serial(
        affine,
        res_unit,
        layer_norm_or_identity)

  if proj_into_2d == 'FLAT':
    proj_layer = stax.Flatten(spec=spec)
  elif proj_into_2d == 'POOL':
    proj_layer = globalPool_fn(spec=spec)
  elif proj_into_2d.startswith('ATTN'):
    n_heads = int(np.sqrt(width))
    n_chan_val = int(np.round(float(width) / n_heads))
    fixed = proj_into_2d == 'ATTN_FIXED'
    proj_layer = stax.serial(
        stax.GlobalSelfAttention(
            n_chan_out=width,
            n_chan_key=width,
            n_chan_val=n_chan_val,
            n_heads=n_heads,
            fixed=fixed,
            W_key_std=W_std,
            W_value_std=W_std,
            W_query_std=W_std,
            W_out_std=1.0,
            b_std=b_std,
            spec=spec), stax.Flatten(spec=spec))
  else:
    raise ValueError(proj_into_2d)
  readout = stax.serial(proj_layer, fc(1 if is_ntk else width))

  return stax.serial(block, readout), input_shape
Example #7
  def test_kwargs(self, do_batch, mode):
    rng = random.PRNGKey(1)

    x_train = random.normal(rng, (8, 7, 10))
    x_test = random.normal(rng, (4, 7, 10))
    y_train = random.normal(rng, (8, 1))

    rng_train, rng_test = random.split(rng, 2)

    pattern_train = random.normal(rng, (8, 7, 7))
    pattern_test = random.normal(rng, (4, 7, 7))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(8),
        stax.Relu(),
        stax.Dropout(rate=0.4),
        stax.Aggregate(),
        stax.GlobalAvgPool(),
        stax.Dense(1)
    )

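    # Keyword arguments for the train-train, test-train, and test-test kernel calls.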
    kw_dd = dict(pattern=(pattern_train, pattern_train))
    kw_td = dict(pattern=(pattern_test, pattern_train))
    kw_tt = dict(pattern=(pattern_test, pattern_test))

    if mode == 'mc':
      kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                        batch_size=2 if do_batch else 0)

    elif mode == 'empirical':
      kernel_fn = empirical_kernel_fn(apply_fn)
      if do_batch:
        raise absltest.SkipTest('Batching of empirical kernel is not '
                                'implemented with keyword arguments.')

      for kw in (kw_dd, kw_td, kw_tt):
        kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                       get=('nngp', 'ntk')))

      kw_dd.update(dict(rng=(rng_train, None)))
      kw_td.update(dict(rng=(rng_test, rng_train)))
      kw_tt.update(dict(rng=(rng_test, None)))

    elif mode == 'analytic':
      if do_batch:
        kernel_fn = batch.batch(kernel_fn, batch_size=2)

    else:
      raise ValueError(mode)

    k_dd = kernel_fn(x_train, None, **kw_dd)
    k_td = kernel_fn(x_test, x_train, **kw_td)
    k_tt = kernel_fn(x_test, None, **kw_tt)

    # Infinite time NNGP/NTK.
    predict_fn_gp = predict.gp_inference(k_dd, y_train)
    out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

    if mode == 'empirical':
      for kw in (kw_dd, kw_td, kw_tt):
        kw.pop('get')

    predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                                x_train,
                                                                y_train,
                                                                **kw_dd)
    out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True, **kw_tt)
    self.assertAllClose(out_gp, out_ensemble)

    # Finite time NTK test.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
    out_mse = predict_fn_mse(t=1.,
                             fx_train_0=None,
                             fx_test_0=0.,
                             k_test_train=k_td.ntk)
    out_ensemble = predict_fn_ensemble(t=1.,
                                       get='ntk',
                                       x_test=x_test,
                                       compute_cov=False,
                                       **kw_tt)
    self.assertAllClose(out_mse, out_ensemble)

    # Finite time NNGP train.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
    out_mse = predict_fn_mse(t=2.,
                             fx_train_0=0.,
                             fx_test_0=None,
                             k_test_train=k_td.nngp)
    out_ensemble = predict_fn_ensemble(t=2.,
                                       get='nngp',
                                       x_test=None,
                                       compute_cov=False,
                                       **kw_dd)
    self.assertAllClose(out_mse, out_ensemble)