Example #1
    def test_parallel_in_out_empirical(self, same_inputs):
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        # Shape (3, n, 10) draws three inputs at once, unpacked along axis 0.
        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
        x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))

        x1 = (x1_1, (x1_2, x1_3))
        x2 = (x2_1, (x2_2, x2_3))

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.
        init_fn, apply_fn, _ = net(WIDTH)
        _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))

        kernel_fn = jit(empirical.empirical_nngp_fn(apply_fn))
        batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)

        # Check NTK.
        init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
        _, params = init_fn(net_key, ((-1, 10), ((-1, 10), (-1, 10))))

        kernel_fn = jit(empirical.empirical_ntk_fn(apply_fn))
        batch_kernel_fn = jit(batch.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)
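
A standalone sketch of the same pattern, assuming the public `neural_tangents` entry points (`nt.empirical_nngp_fn`, `nt.batch`) rather than the test-internal `empirical` and `batch` modules:

    from jax import jit, random
    import neural_tangents as nt
    from neural_tangents import stax

    init_fn, apply_fn, _ = stax.serial(stax.Dense(256), stax.Relu(), stax.Dense(1))

    key = random.PRNGKey(0)
    x1 = random.normal(key, (8, 10))
    x2 = random.normal(key, (4, 10))
    _, params = init_fn(key, x1.shape)

    # Empirical (finite-width) NNGP of `apply_fn`, computed whole and in batches.
    kernel_fn = jit(nt.empirical_nngp_fn(apply_fn))
    batched_kernel_fn = nt.batch(kernel_fn, batch_size=2)

    # Both calls should agree up to numerical precision, as asserted above.
    k_full = kernel_fn(x1, x2, params)
    k_batched = batched_kernel_fn(x1, x2, params)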
Example #2
    def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                     store_on_device):
        test_utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            WIDTH, 256,
            xla_bridge.get_backend().platform == 'tpu')

        sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 200,
                                                   batch_size, device_count,
                                                   store_on_device)

        ker_empirical = sample(x1, x2, 'nngp')
        ker_analytic = stax_kernel_fn(x1, x2, 'nngp')

        test_utils.assert_close_matrices(self, ker_analytic, ker_empirical,
                                         2e-2)
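
A stripped-down version of the same NNGP comparison, assuming the public `nt.monte_carlo_kernel_fn` and a simple erf network in place of the test's `_get_inputs_and_model` helper:

    from jax import random
    import neural_tangents as nt
    from neural_tangents import stax

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(512, W_std=1.5, b_std=0.1), stax.Erf(), stax.Dense(512))

    key1, key2, key3 = random.split(random.PRNGKey(0), 3)
    x1 = random.normal(key1, (6, 10))
    x2 = random.normal(key2, (3, 10))

    # Monte Carlo estimate over 128 random initializations.
    sample = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key3, n_samples=128)

    ker_empirical = sample(x1, x2, 'nngp')
    ker_analytic = kernel_fn(x1, x2, 'nngp')
    # For wide layers and enough samples the two agree to within a few percent.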
Example #3
    def test_parallel_in_out(self, same_inputs):
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, mc_key = random.split(rng, 3)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
        x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))

        x1 = (x1_1, (x1_2, x1_3))
        x2 = (x2_1, (x2_2, x2_3))

        N = WIDTH

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.
        readin = net(N)
        readout = net(1)

        K_readin_fn = jit(readin[2])
        K_readout_fn = jit(partial(readout[2], get='nngp'))

        batch_K_readin_fn = batch.batch(K_readin_fn, 2)
        batch_K_readout_fn = batch.batch(K_readout_fn, 2)

        test_utils.assert_close_matrices(
            self, K_readout_fn(K_readin_fn(x1, x2)),
            batch_K_readout_fn(batch_K_readin_fn(x1, x2)), RTOL)

        # Check both NNGP and NTK.
        K_readin_fn = jit(readin[2])
        K_readout_fn = jit(partial(readout[2], get=('nngp', 'ntk')))

        batch_K_readin_fn = batch.batch(K_readin_fn, 2)
        batch_K_readout_fn = batch.batch(K_readout_fn, 2)

        get_ntk = utils.nt_tree_fn()(lambda k: k.ntk)

        test_utils.assert_close_matrices(
            self, get_ntk(K_readout_fn(K_readin_fn(x1, x2))),
            get_ntk(batch_K_readout_fn(batch_K_readin_fn(x1, x2))), RTOL)
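
The readin/readout chaining above works because an analytic `kernel_fn` accepts either raw inputs or the `Kernel` object another `kernel_fn` produced. A minimal sketch on plain (non-tuple) inputs, assuming the public `nt.batch` wrapper:

    from functools import partial
    from jax import jit, random
    import neural_tangents as nt
    from neural_tangents import stax

    readin = stax.serial(stax.Dense(512), stax.Relu())
    readout = stax.serial(stax.Dense(1))

    # Element [2] of a stax triple is the analytic kernel_fn.
    K_readin_fn = jit(readin[2])
    K_readout_fn = jit(partial(readout[2], get='nngp'))

    x1 = random.normal(random.PRNGKey(0), (8, 10))
    x2 = random.normal(random.PRNGKey(1), (4, 10))

    # Unbatched and batched compositions should agree.
    k = K_readout_fn(K_readin_fn(x1, x2))
    k_batched = nt.batch(K_readout_fn, 2)(nt.batch(K_readin_fn, 2)(x1, x2))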
Example #4
    def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count,
                                         store_on_device):
        test_utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
            256, 2,
            xla_bridge.get_backend().platform == 'tpu')

        sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 100,
                                                   batch_size, device_count,
                                                   store_on_device)

        ker_empirical = sample(x1, x2, 'ntk')
        # Collapse the two logit axes: off-diagonal entries vanish in
        # expectation, so summing both axes and dividing by the number of
        # logits recovers the scalar kernel.
        ker_empirical = (np.sum(ker_empirical, axis=(-1, -2)) /
                         ker_empirical.shape[-1])

        ker_analytic = stax_kernel_fn(x1, x2, 'ntk')

        test_utils.assert_close_matrices(self, ker_analytic, ker_empirical,
                                         2e-2)
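
The sum-and-divide step above is a normalized trace over the two logit axes. A tiny illustration with a hypothetical kernel of shape `(n1, n2, d, d)`, where the empirical NTK is approximately the scalar kernel times the identity over logits:

    import jax.numpy as np

    d = 5
    # Idealized empirical NTK for 3 x 2 input pairs: scalar kernel (= 1) times I_d.
    k = np.broadcast_to(np.eye(d), (3, 2, d, d))
    k_scalar = np.sum(k, axis=(-1, -2)) / k.shape[-1]  # shape (3, 2), all ones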
Example #5
    def _check_agreement_with_empirical(self, net, same_inputs, is_conv,
                                        use_dropout, is_ntk, proj_into_2d):

        (init_fn, apply_fn, kernel_fn), input_shape = net

        num_samples = N_SAMPLES * 5 if use_dropout else N_SAMPLES
        key = random.PRNGKey(1)
        x1, x2 = _get_inputs(key, is_conv, same_inputs, input_shape)

        x1_out_shape, params = init_fn(key, x1.shape)
        if same_inputs:
            assert x2 is None
        if x2 is None:
            x2_out_shape = x1_out_shape
        else:
            x2_out_shape, params = init_fn(key, x2.shape)
        del params

        def _get_empirical(n_samples, get):
            kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn(
                init_fn, apply_fn, key, n_samples)
            if same_inputs:
                assert x2 is None
            return kernel_fn_empirical(x1, x2, get)

        if proj_into_2d == 'ATTN_PARAM':
            # No analytic kernel available; just test the forward/backward pass.
            _get_empirical(1, 'ntk' if is_ntk else 'nngp')
        else:
            if is_ntk:
                exact, shape1, shape2 = kernel_fn(x1, x2,
                                                  ('ntk', 'shape1', 'shape2'))
                empirical = np.reshape(_get_empirical(num_samples, 'ntk'),
                                       exact.shape)
            else:
                exact, shape1, shape2 = kernel_fn(x1, x2,
                                                  ('nngp', 'shape1', 'shape2'))
                empirical = _get_empirical(num_samples, 'nngp')
            test_utils.assert_close_matrices(self, exact, empirical, RTOL)
            self.assertEqual(shape1, x1_out_shape)
            self.assertEqual(shape2, x2_out_shape)
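
A minimal sketch of the tuple-valued `get` used above, assuming a plain dense `stax` model; requesting several quantities at once returns a namedtuple with one field per name:

    from jax import random
    from neural_tangents import stax

    _, _, kernel_fn = stax.serial(
        stax.Dense(512, W_std=1.25, b_std=0.1), stax.Relu(), stax.Dense(1))

    x1 = random.normal(random.PRNGKey(0), (6, 8))
    x2 = random.normal(random.PRNGKey(1), (3, 8))

    kernels = kernel_fn(x1, x2, ('nngp', 'ntk'))
    nngp, ntk = kernels.nngp, kernels.ntk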
Example #6
    def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                         readout):
        if xla_bridge.get_backend().platform == 'cpu':
            raise jtu.SkipTest('Not running CNNs on CPU to save time.')

        if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
            raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                               'require `is_gaussian`.')

        if axis == 3 and branch_in == 'dense_before_branch_in':
            raise jtu.SkipTest(
                '`FanInConcat` on feature axis requires a dense layer '
                'after concatenation.')

        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (2, 5, 6, 3))
        X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

        if xla_bridge.get_backend().platform == 'tpu':
            width = 2048
            n_samples = 1024
            tol = 0.02
        else:
            width = 1024
            n_samples = 512
            tol = 0.01

        conv = stax.Conv(out_chan=width,
                         filter_shape=(3, 3),
                         padding='SAME',
                         W_std=1.25,
                         b_std=0.1)

        input_layers = [conv, stax.FanOut(n_branches)]

        branches = []
        for b in range(n_branches):
            branch_layers = [FanInTest._get_phi(b)]
            for i in range(b):
                branch_layers += [
                    stax.Conv(out_chan=width,
                              filter_shape=(i + 1, 4 - i),
                              padding='SAME',
                              W_std=1.25 + i,
                              b_std=0.1 + i),
                    FanInTest._get_phi(i)
                ]

            if branch_in == 'dense_before_branch_in':
                branch_layers += [conv]
            branches += [stax.serial(*branch_layers)]

        output_layers = [
            stax.FanInSum() if axis is None else stax.FanInConcat(axis),
            stax.Relu(),
            stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
        ]
        if branch_in == 'dense_after_branch_in':
            output_layers.insert(1, conv)

        nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                           output_layers))

        init_fn, apply_fn, kernel_fn = stax.serial(
            nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

        kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if axis in (0, -4) else -1)

        exact = kernel_fn(X0_1, X0_2, get=get)
        empirical = kernel_fn_mc(X0_1, X0_2, get=get)
        empirical = empirical.reshape(exact.shape)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
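
A much smaller fan-in model in the same spirit, assuming dense layers in place of convolutions; note the dense layer placed directly after `FanInConcat` on the feature axis, which is the constraint the skips above enforce:

    from jax import random
    from neural_tangents import stax

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(512), stax.FanOut(2),
        stax.parallel(stax.Relu(), stax.Erf()),
        stax.FanInConcat(axis=-1),
        stax.Dense(512),  # dense layer required after feature-axis concatenation
        stax.Relu(),
        stax.Dense(1))

    x1 = random.normal(random.PRNGKey(0), (4, 16))
    exact_ntk = kernel_fn(x1, None, 'ntk')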
Example #7
    def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
        if axis in (None, 0) and branch_in == 'dense_after_branch_in':
            raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                               'require `is_gaussian`.')

        if axis == 1 and branch_in == 'dense_before_branch_in':
            raise jtu.SkipTest(
                '`FanInConcat` on feature axis requires a dense layer '
                'after concatenation.')

        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (10, 20))
        X0_2 = None if same_inputs else random.normal(key, (8, 20))

        if xla_bridge.get_backend().platform == 'tpu':
            width = 2048
            n_samples = 1024
            tol = 0.02
        else:
            width = 1024
            n_samples = 256
            tol = 0.01

        dense = stax.Dense(width, 1.25, 0.1)
        input_layers = [dense, stax.FanOut(n_branches)]

        branches = []
        for b in range(n_branches):
            branch_layers = [FanInTest._get_phi(b)]
            for i in range(b):
                branch_layers += [
                    stax.Dense(width, 1. + 2 * i, 0.5 + i),
                    FanInTest._get_phi(i)
                ]

            if branch_in == 'dense_before_branch_in':
                branch_layers += [dense]
            branches += [stax.serial(*branch_layers)]

        output_layers = [
            stax.FanInSum() if axis is None else stax.FanInConcat(axis),
            stax.Relu()
        ]
        if branch_in == 'dense_after_branch_in':
            output_layers.insert(1, dense)

        nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                           output_layers))

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = nn
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(1, 1.25, 0.5))
        else:
            raise ValueError(get)

        kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(init_fn,
                                                         apply_fn,
                                                         key,
                                                         n_samples,
                                                         device_count=0)

        exact = kernel_fn(X0_1, X0_2, get=get)
        empirical = kernel_fn_mc(X0_1, X0_2, get=get)
        empirical = empirical.reshape(exact.shape)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
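
For reference, `device_count=0` keeps all sampling on a single device (no `pmap` sharding). A minimal sketch, assuming the public `nt.monte_carlo_kernel_fn`:

    from jax import random
    import neural_tangents as nt
    from neural_tangents import stax

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(512, 1.25, 0.1), stax.Relu(), stax.Dense(1, 1.25, 0.5))

    x1 = random.normal(random.PRNGKey(0), (10, 20))

    # device_count=0 draws every sample on the default device, without pmap.
    kernel_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, random.PRNGKey(1), n_samples=128, device_count=0)

    empirical = kernel_fn_mc(x1, None, get='ntk')
    exact = kernel_fn(x1, None, get='ntk')
    # Depending on the library version, `empirical` may carry trailing output
    # axes and need reshaping to `exact.shape`, as done in the test above.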