    def test_monte_carlo_generator(self, batch_size, device_count,
                                   store_on_device, get):
        test_utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            8, 1)
        x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

        log_n_max = 4
        n_samples = [2**k for k in range(log_n_max)]
        sample_generator = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n_samples, batch_size, device_count,
            store_on_device)

        if get is None:
            samples_12 = sample_generator(x1, x2)
            samples_34 = sample_generator(x3, x4)

            count = 0
            for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
                sample_fn = monte_carlo.monte_carlo_kernel_fn(
                    init_fn, apply_fn, key, n, batch_size, device_count,
                    store_on_device)
                sample_12 = sample_fn(x1, x2)
                sample_34 = sample_fn(x3, x4)
                self.assertAllClose(s_12, sample_12)
                self.assertAllClose(s_12, s_34)
                self.assertAllClose(s_12, sample_34)
                count += 1

            self.assertEqual(log_n_max, count)

            ker_analytic_12 = stax_kernel_fn(x1, x2, ('nngp', 'ntk'))
            ker_analytic_34 = stax_kernel_fn(x3, x4, ('nngp', 'ntk'))

        else:
            samples_12 = sample_generator(x1, x2, get)
            samples_34 = sample_generator(x3, x4, get)

            count = 0
            for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
                sample_fn = monte_carlo.monte_carlo_kernel_fn(
                    init_fn, apply_fn, key, n, batch_size, device_count,
                    store_on_device)
                sample_12 = sample_fn(x1, x2, get)
                sample_34 = sample_fn(x3, x4, get)
                self.assertAllClose(s_12, sample_12)
                self.assertAllClose(s_12, s_34)
                self.assertAllClose(s_12, sample_34)
                count += 1

            self.assertEqual(log_n_max, count)

            ker_analytic_12 = stax_kernel_fn(x1, x2, get)
            ker_analytic_34 = stax_kernel_fn(x3, x4, get)

        self.assertAllClose(ker_analytic_12, s_12, atol=2., rtol=2.)
        self.assertAllClose(ker_analytic_12, ker_analytic_34)
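These tests call a helper `_get_inputs_and_model` that is not included in this listing. A minimal sketch of what it plausibly returns, inferred from the call sites and assuming the module's usual imports (`jax.random` as `random`, `neural_tangents.stax` as `stax`); the widths, input shapes and architecture below are assumptions, not the actual implementation:

def _get_inputs_and_model(width=1, n_outputs=2, use_conv=False):
  # Hypothetical reconstruction: two input batches, a finite network
  # (init_fn, apply_fn) with its analytic kernel function, and a PRNG key.
  key, split = random.split(random.PRNGKey(1))
  x1 = random.normal(key, (8, 4, 3, 2) if use_conv else (8, 24))
  x2 = random.normal(split, (4, 4, 3, 2) if use_conv else (4, 24))
  init_fn, apply_fn, stax_kernel_fn = stax.serial(
      stax.Conv(width, (3, 3)) if use_conv else stax.Dense(width),
      stax.Relu(),
      stax.Flatten() if use_conv else stax.Identity(),
      stax.Dense(n_outputs))
  return x1, x2, init_fn, apply_fn, stax_kernel_fn, key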
Example #2
    def test_parallel_in_out_mc(self, same_inputs, batch_size):
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 2, 5))
        x1 = (x1_1, (x1_2, x1_3))

        if same_inputs:
            x2 = None
        else:
            x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 4, 5))
            x2 = (x2_1, (x2_2, x2_3))

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.
        init_fn, apply_fn, _ = net(WIDTH)

        nb_kernel_fn = monte_carlo.monte_carlo_kernel_fn(init_fn,
                                                         apply_fn,
                                                         net_key,
                                                         n_samples=4,
                                                         trace_axes=(-1, ))

        kernel_fn = monte_carlo.monte_carlo_kernel_fn(init_fn,
                                                      apply_fn,
                                                      net_key,
                                                      n_samples=4,
                                                      batch_size=batch_size,
                                                      trace_axes=(-1, ))

        self.assertAllClose(kernel_fn(x1, x2, 'nngp'),
                            nb_kernel_fn(x1, x2, 'nngp'))
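For intuition, the quantity `monte_carlo_kernel_fn` estimates for `'nngp'` is an outer product of network outputs, traced over `trace_axes` and averaged over `n_samples` independent parameter draws. A minimal sketch of a single draw, assuming plain array inputs/outputs and `trace_axes=(-1,)` (pytree inputs, as in the test above, are handled per-leaf by the library):

def nngp_single_draw(apply_fn, params, x1, x2):
  # One Monte Carlo draw of the NNGP estimate: outer product of outputs,
  # averaged (traced) over the output axis.
  f1 = apply_fn(params, x1)        # shape (n1, d_out)
  f2 = apply_fn(params, x2)        # shape (n2, d_out)
  return f1 @ f2.T / f1.shape[-1]  # shape (n1, n2)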
Example #3
    def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                     store_on_device):
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            1024, 256,
            xla_bridge.get_backend().platform == 'tpu')

        sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 200,
                                                   batch_size, device_count,
                                                   store_on_device)

        ker_empirical = sample(x1, x2, 'nngp')
        ker_analytic = stax_kernel_fn(x1, x2, 'nngp')

        utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)
Example #4
    def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                         store_on_device):
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            256, 2,
            xla_bridge.get_backend().platform == 'tpu')

        sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 100,
                                                   batch_size, device_count,
                                                   store_on_device)

        ker_empirical = sample(x1, x2, 'ntk')
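        # The empirical NTK here retains per-output axes (roughly
        # (n1, n2, d_out, d_out)); for independent outputs the analytic NTK
        # is ~ delta_{ij} * Theta, so summing over both output axes and
        # dividing by d_out recovers Theta of shape (n1, n2).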
        ker_empirical = (np.sum(ker_empirical, axis=(-1, -2)) /
                         ker_empirical.shape[-1])

        ker_analytic = stax_kernel_fn(x1, x2, 'ntk')

        utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)
Example #5
    def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                         store_on_device):
        test_utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            WIDTH, 2,
            xla_bridge.get_backend().platform == 'tpu')

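        # vmap_axes=0 declares that inputs and outputs are batched along the
        # leading axis, letting the empirical kernel vectorize with vmap
        # instead of looping over example pairs.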
        sample = monte_carlo.monte_carlo_kernel_fn(init_fn,
                                                   apply_fn,
                                                   key,
                                                   100,
                                                   batch_size,
                                                   device_count,
                                                   store_on_device,
                                                   vmap_axes=0)

        ker_empirical = sample(x1, x2, 'ntk')
        ker_analytic = stax_kernel_fn(x1, x2, 'ntk')

        test_utils.assert_close_matrices(self, ker_analytic, ker_empirical,
                                         2e-2)
Example #6
  def test_sparse_inputs(self, act, kernel):
    key = random.PRNGKey(1)

    input_count = 4
    sparse_count = 2
    input_size = 128
    width = 4096

    # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
    samples = N_SAMPLES

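    # Loosen jax.test_util's default comparison tolerances (a private,
    # version-specific hook), since Monte Carlo estimates converge slowly
    # on sparse inputs.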
    if xla_bridge.get_backend().platform == 'gpu':
      jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-4
      samples = 100 * N_SAMPLES
    else:
      jtu._default_tolerance[np.onp.dtype(np.onp.float32)] = 5e-2
      jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-3

    # a batch of dense inputs
    x_dense = random.normal(key, (input_count, input_size))
    x_sparse = ops.index_update(x_dense, ops.index[:sparse_count, :], 0.)
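    # NOTE: ops.index_update (above) is the legacy JAX indexed-update API;
    # on current JAX the equivalent is:
    #   x_sparse = x_dense.at[:sparse_count, :].set(0.)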

    activation = stax.Relu() if act == 'relu' else stax.Erf()

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(width),
        activation,
        stax.Dense(1 if kernel == 'ntk' else width))
    exact = kernel_fn(x_sparse, None, kernel)
    mc = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn,
                                           random.split(key, 2)[0],
                                           samples)(x_sparse, None, kernel)
    mc = np.reshape(mc, exact.shape)

    assert not np.any(np.isnan(exact))
    self.assertAllClose(exact[sparse_count:, sparse_count:],
                        mc[sparse_count:, sparse_count:], True)
Example #7
def _get_empirical(n_samples, get):
  kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples)
  return kernel_fn_empirical(x1, x2, get)
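A hypothetical call to this closure from the enclosing test, which supplies `init_fn`, `apply_fn`, `key`, `x1` and `x2` (the sample count here is illustrative):

nngp = _get_empirical(n_samples=128, get='nngp')
ntk = _get_empirical(n_samples=128, get='ntk')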
Example #8
    def test_monte_carlo_generator(self, batch_size, device_count,
                                   store_on_device, get):
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            8, 1)
        x3, x4, _, _, _, _ = _get_inputs_and_model(8, 1)

        log_n_max = 4
        n_samples = [2**k for k in range(log_n_max)]
        sample_generator = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n_samples, batch_size, device_count,
            store_on_device)

        if get is None:
            samples_12 = sample_generator(x1, x2)
            samples_34 = sample_generator(x3, x4)

            count = 0
            for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
                sample_fn = monte_carlo.monte_carlo_kernel_fn(
                    init_fn, apply_fn, key, n, batch_size, device_count,
                    store_on_device)
                sample_12 = sample_fn(x1, x2)
                sample_34 = sample_fn(x3, x4)
                self.assertAllClose(s_12, sample_12, True)
                self.assertAllClose(s_12, s_34, True)
                self.assertAllClose(s_12, sample_34, True)
                count += 1

            self.assertEqual(log_n_max, count)

            ker_analytic_12 = stax_kernel_fn(x1, x2)
            ker_analytic_34 = stax_kernel_fn(x3, x4)

        else:
            samples_12 = sample_generator(x1, x2, get)
            samples_34 = sample_generator(x3, x4, get)

            count = 0
            for n, s_12, s_34 in zip(n_samples, samples_12, samples_34):
                sample_fn = monte_carlo.monte_carlo_kernel_fn(
                    init_fn, apply_fn, key, n, batch_size, device_count,
                    store_on_device)
                sample_12 = sample_fn(x1, x2, get)
                sample_34 = sample_fn(x3, x4, get)
                self.assertAllClose(s_12, sample_12, True)
                self.assertAllClose(s_12, s_34, True)
                self.assertAllClose(s_12, sample_34, True)
                count += 1

            self.assertEqual(log_n_max, count)

            ker_analytic_12 = stax_kernel_fn(x1, x2, get)
            ker_analytic_34 = stax_kernel_fn(x3, x4, get)

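        # In this older API the empirical NTK keeps trailing per-output axes;
        # with a single output they are size-1 and are squeezed away before
        # comparing against the analytic kernel.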
        if get == 'ntk':
            s_12 = np.squeeze(s_12, (-1, -2))
        elif get is None or 'ntk' in get:
            s_12 = s_12._replace(ntk=np.squeeze(s_12.ntk, (-1, -2)))

        self.assertAllClose(ker_analytic_12, s_12, True, 2., 2.)
        self.assertAllClose(ker_analytic_12, ker_analytic_34, True)
Example #9
def _get_empirical(n_samples, get):
    kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples)
    if same_inputs:
        assert x2 is None
    return kernel_fn_empirical(x1, x2, get)
Example #10
  def test_fan_in_conv(self,
                       same_inputs,
                       axis,
                       n_branches,
                       get,
                       branch_in,
                       readout):
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNNs on CPU to save time.')

    if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
      raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                         'require `is_gaussian`.')

    if axis == 3 and branch_in == 'dense_before_branch_in':
      raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                         'after concatenation.')

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (2, 5, 6, 3))
    X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

    if xla_bridge.get_backend().platform == 'tpu':
      width = 2048
      n_samples = 1024
      tol = 0.02
    else:
      width = 1024
      n_samples = 512
      tol = 0.01

    conv = stax.Conv(out_chan=width,
                     filter_shape=(3, 3),
                     padding='SAME',
                     W_std=1.25,
                     b_std=0.1)

    input_layers = [conv,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        branch_layers += [
            stax.Conv(
                out_chan=width,
                filter_shape=(i + 1, 4 - i),
                padding='SAME',
                W_std=1.25 + i,
                b_std=0.1 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [conv]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        stax.FanInSum() if axis is None else stax.FanInConcat(axis),
        stax.Relu(),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, conv)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    init_fn, apply_fn, kernel_fn = stax.serial(
        nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

    kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        key,
        n_samples,
        device_count=0 if axis in (0, -4) else -1)

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    empirical = empirical.reshape(exact.shape)
    utils.assert_close_matrices(self, empirical, exact, tol)
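Both this test and the next reference `FanInTest._get_phi`, which is not part of this listing. A plausible minimal sketch, assuming it maps a branch index to a stax activation (the particular activations chosen are an assumption):

  @classmethod
  def _get_phi(cls, i):
    # Hypothetical reconstruction: pick a nonlinearity per branch index so
    # that different branches apply different activations.
    return {0: stax.Identity(),
            1: stax.Erf(),
            2: stax.Abs()}[i % 3]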
Example #11
  def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
    if axis in (None, 0) and branch_in == 'dense_after_branch_in':
      raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                         'require `is_gaussian`.')

    if axis == 1 and branch_in == 'dense_before_branch_in':
      raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                         'after concatenation.')

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (10, 20))
    X0_2 = None if same_inputs else random.normal(key, (8, 20))

    if xla_bridge.get_backend().platform == 'tpu':
      width = 2048
      n_samples = 1024
      tol = 0.02
    else:
      width = 1024
      n_samples = 256
      tol = 0.01

    dense = stax.Dense(width, 1.25, 0.1)
    input_layers = [dense,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        branch_layers += [
            stax.Dense(width, 1. + 2 * i, 0.5 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [dense]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        stax.FanInSum() if axis is None else stax.FanInConcat(axis),
        stax.Relu()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, dense)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    if get == 'nngp':
      init_fn, apply_fn, kernel_fn = nn
    elif get == 'ntk':
      init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
    else:
      raise ValueError(get)

    kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, device_count=0)

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    empirical = empirical.reshape(exact.shape)
    utils.assert_close_matrices(self, empirical, exact, tol)