def _kernel_fns(key,
                input_shape,
                network,
                out_logits,
                diagonal_axes,
                trace_axes,
                vmap_axes=None):
  """Builds jitted empirical kernel functions with network parameters bound.

  Args:
    key: PRNG key passed to the network's `init_fn`.
    input_shape: per-example input shape as a tuple; a leading `-1` batch
      dimension is prepended before initialization.
    network: forwarded to `_build_network`.
    out_logits: forwarded to `_build_network`.
    diagonal_axes: forwarded to every kernel constructor.
    trace_axes: forwarded to every kernel constructor.
    vmap_axes: forwarded to the two NTK constructors only (the NNGP
      constructor does not receive it).

  Returns:
    A 3-tuple `(implicit_ntk_fn, direct_ntk_fn, nngp_fn)` of jitted kernel
    functions, each with `params=` already bound to the freshly initialized
    network parameters.
  """
  init_fn, apply_fn, _ = _build_network(input_shape, network, out_logits)
  _, net_params = init_fn(key, (-1,) + input_shape)

  # The implicit and direct NTK constructors take the same positional
  # arguments, so build both from one expression.
  ntk_fns = [
      jit(make_ntk(apply_fn, trace_axes, diagonal_axes, vmap_axes))
      for make_ntk in (empirical._empirical_implicit_ntk_fn,
                       empirical._empirical_direct_ntk_fn)
  ]
  nngp_fn = jit(empirical.empirical_nngp_fn(apply_fn, trace_axes,
                                            diagonal_axes))

  return tuple(partial(kernel_fn, params=net_params)
               for kernel_fn in (*ntk_fns, nngp_fn))
# Example #2
    def test_parallel_in_out_empirical(self, same_inputs):
        """Checks that `batch.batch` preserves empirical NNGP and NTK values.

        The network takes nested parallel inputs `(a, (b, c))` and produces
        matching nested parallel outputs.

        Args:
          same_inputs: if True, kernels are evaluated with `x2=None`, i.e.
            `x1` against itself; otherwise `x1` is evaluated against a
            separate `x2`.
        """
        # Stub out `pmap` inside `batch` with a device count of 2 --
        # presumably so the parallel batching path can run on a single host.
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        # Three 10-feature inputs per example: 4 examples in x1, 8 in x2.
        # Both counts are divisible by the batch size of 2 used below.
        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
        x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))

        x1 = (x1_1, (x1_2, x1_3))
        # Honor `same_inputs`. Otherwise both parameterizations would run the
        # identical x1-vs-x2 case, and the `x2=None` path would never be tested.
        x2 = None if same_inputs else (x2_1, (x2_2, x2_3))

        def net(N_out):
            # Nested parallel Dense layers with widths N_out, N_out+1, N_out+2.
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.
        init_fn, apply_fn, _ = net(WIDTH)
        # The input-shape tree mirrors the nested structure of `x1`.
        _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))

        kernel_fn = jit(empirical.empirical_nngp_fn(apply_fn))
        batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)

        # Check NTK.
        init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
        _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))

        kernel_fn = jit(empirical.empirical_ntk_fn(apply_fn))
        batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)
  def test_parallel_in_out(self, same_inputs):
    """Empirical NTK/NNGP agree across implementations for a 2-in/2-out net.

    The un-vmapped direct NTK is the reference. The implicit NTK and the
    `vmap_axes=(0, 0)` variants of both implementations must match it. The
    NNGP must return one kernel per output branch, each of shape `(3, 3)`
    when `same_inputs` is True, else `(3, 4)`.
    """
    key_in1, key_in2, key_net = random.split(random.PRNGKey(0), 3)

    # Split 21 features at column 10: each example gets a 10-feature and an
    # 11-feature input, one per parallel branch.
    lhs_a, lhs_b = np.split(random.normal(key_in1, (3, 21)), (10,), axis=1)
    rhs_a, rhs_b = np.split(random.normal(key_in2, (4, 21)), (10,), axis=1)

    lhs = (lhs_a, lhs_b)
    rhs = None if same_inputs else (rhs_a, rhs_b)

    def two_branch_dense(width):
      # One Dense layer per branch, with widths `width` and `width + 1`.
      return stax.parallel(stax.Dense(width), stax.Dense(width + 1))

    net_init, net_apply, _ = stax.serial(two_branch_dense(1024),
                                         two_branch_dense(1))
    _, net_params = net_init(key_net, (lhs_a.shape, lhs_b.shape))

    # The direct, un-vmapped NTK is the reference for the other variants.
    reference_ntk = jit(empirical._empirical_direct_ntk_fn(net_apply))(
        lhs, rhs, net_params)
    alternative_ntk_fns = (
        jit(empirical._empirical_implicit_ntk_fn(net_apply)),
        jit(empirical._empirical_direct_ntk_fn(net_apply, vmap_axes=(0, 0))),
        jit(empirical._empirical_implicit_ntk_fn(net_apply, vmap_axes=(0, 0))),
    )
    for ntk_fn in alternative_ntk_fns:
      self.assertAllClose(reference_ntk, ntk_fn(lhs, rhs, net_params))

    nngp = jit(empirical.empirical_nngp_fn(net_apply))(lhs, rhs, net_params)
    expected_shape = (3, 3 if same_inputs else 4)
    self.assertEqual(len(nngp), 2)
    for branch_nngp in nngp:
      self.assertEqual(branch_nngp.shape, expected_shape)