Example #1
    def testMaxLearningRate(self, train_shape, network, out_logits,
                            fn_and_kernel):

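        # Derive a deterministic int32 key pair; with both minval and maxval
        # None, stateless_uniform samples the full int32 range.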
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        if len(train_shape) == 2:
            train_shape = (train_shape[0] * 5, train_shape[1] * 10)
        else:
            train_shape = (16, 8, 8, 3)
        x_train = np.asarray(normal(train_shape, seed=split))

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        y_train = np.asarray(
            stateless_uniform(shape=(train_shape[0], out_logits),
                              seed=split,
                              minval=0,
                              maxval=1) < 0.5, np.float32)
        # Regress to an MSE loss.
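        # Note: `f` is bound later, inside the lr_factor loop; the lambda
        # resolves it at call time via closure.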
        loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train)**2)
        grad_loss = jit(grad(loss))

        def get_loss(opt_state):
            return loss(get_params(opt_state), x_train)

        steps = 20

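        # A step size below the critical learning rate should make training
        # converge; one above it should make the loss diverge (or go NaN).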
        for lr_factor in [0.5, 3.]:
            params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                           out_logits)
            g_dd = ntk(x_train, None, 'ntk')

            step_size = predict.max_learning_rate(
                g_dd, y_train_size=y_train.size) * lr_factor
            opt_init, opt_update, get_params = optimizers.sgd(step_size)
            opt_state = opt_init(params)

            init_loss = get_loss(opt_state)

            for i in range(steps):
                params = get_params(opt_state)
                opt_state = opt_update(i, grad_loss(params, x_train),
                                       opt_state)

            trained_loss = get_loss(opt_state)
            loss_ratio = trained_loss / (init_loss + 1e-12)
            if lr_factor == 3.:
                if not math.isnan(loss_ratio):
                    self.assertGreater(loss_ratio, 10.)
            else:
                self.assertLess(loss_ratio, 0.1)
    def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size):
        test_utils.stub_out_pmap(batch, 2)

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key, 3)
        key = keys[0]
        self_split = keys[1]
        other_split = keys[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))

        kernel_fn = kernel_fn(key, train_shape[1:], network)

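        # The batched kernel must agree with the unbatched one, whether
        # intermediate results are stored on device or on host.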
        kernel_batched = batch.batch(kernel_fn, batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batch.batch(kernel_fn,
                                     batch_size=batch_size,
                                     store_on_device=False)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)
    def testNTKAgainstDirect(self, train_shape, test_shape, network, name,
                             kernel_fn):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=3)
        key = splits[0]
        self_split = splits[1]
        other_split = splits[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))

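        # The implicit and direct empirical-NTK computations must agree on
        # both the self-kernel and the cross-kernel.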
        implicit, direct, _ = kernel_fn(key,
                                        train_shape[1:],
                                        network,
                                        diagonal_axes=(),
                                        trace_axes=())

        g = implicit(data_self, None)
        g_direct = direct(data_self, None)
        self.assertAllClose(g, g_direct)

        g = implicit(data_other, data_self)
        g_direct = direct(data_other, data_self)
        self.assertAllClose(g, g_direct)
    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(rng)
        rng_self = keys[0]
        rng_other = keys[1]
        x_self = np.asarray(normal((8, 10), seed=rng_self))
        x_other = np.asarray(normal((2, 10), seed=rng_other))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

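        # Applying the block kernel twice should equal the kernel of the
        # two-block composed network.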
        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            # In the parallel setting, `x1_is_x2` is not computed correctly
            # when x1==x2.
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        # Check convolutional + pooling.
        x_self = np.asarray(normal((8, 10, 10, 3), seed=rng_self))
        x_other = np.asarray(normal((2, 10, 10, 3), seed=rng_other))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
        ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
def _empirical_kernel(key, input_shape, network, out_logits, use_dropout):
    init_fn, f, _ = _build_network(input_shape, network, out_logits,
                                   use_dropout)
    keys = tf_random_split(key)
    key = keys[0]
    split = keys[1]
    _, params = init_fn(key, (1, ) + input_shape)
    kernel_fn = jit(empirical.empirical_ntk_fn(f))
    return partial(kernel_fn, params=params, keys=split)
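# A minimal usage sketch for the helper above (hypothetical shapes and
# out_logits; `CONV`, `normal` and `stateless_uniform` as used elsewhere in
# these tests). The returned partial computes the empirical NTK with params
# and dropout keys already bound:
#
#   key = stateless_uniform(shape=[2], seed=[0, 0], minval=None, maxval=None,
#                           dtype=tf.int32)
#   ntk_fn = _empirical_kernel(key, (5, 6, 3), CONV, out_logits=1,
#                              use_dropout=False)
#   x = np.asarray(normal((4, 5, 6, 3), seed=key))
#   g_dd = ntk_fn(x, None)  # empirical NTK block of x with itself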
    def testLinearization(self, shape):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=4)
        key = splits[0]
        s1 = splits[1]
        s2 = splits[2]
        s3 = splits[3]
        w1 = np.asarray(normal(shape, seed=s1))
        w1 = 0.5 * (w1 + w1.T)
        w2 = np.asarray(normal(shape, seed=s2))
        b = np.asarray(normal((shape[-1], ), seed=s3))
        params = (w1, w2, b)

        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=2)
        key = splits[0]
        split = splits[1]
        x0 = np.asarray(normal((shape[-1], ), seed=split))

        f_lin = empirical.linearize(EmpiricalTest.f, x0)

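        # The linearization should match the exact first-order expansion at
        # random evaluation points, under all parameter/input alterations.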
        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in [True, False]:
                for do_shift_x in [True, False]:
                    splits = tf_random_split(seed=tf.convert_to_tensor(
                        key, dtype=tf.int32),
                                             num=2)
                    key = splits[0]
                    split = splits[1]
                    x = np.asarray(normal((shape[-1], ), seed=split))
                    self.assertAllClose(
                        EmpiricalTest.f_lin_exact(x0,
                                                  x,
                                                  params,
                                                  do_alter,
                                                  do_shift_x=do_shift_x),
                        f_lin(x, params, do_alter, do_shift_x=do_shift_x))
Example #7
        def predict_mc(count, key):
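            # `predict_empirical` presumably comes from the enclosing test
            # scope; each split key drives one Monte Carlo draw.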
            keys = tf_random_split(key, count)
            fx_train, fx_test = vmap(predict_empirical)(keys)
            fx_train_mean = np.mean(fx_train, axis=0)
            fx_test_mean = np.mean(fx_test, axis=0)

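            # Center the draws before estimating the empirical covariance.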
            fx_train_centered = fx_train - fx_train_mean
            fx_test_centered = fx_test - fx_test_mean

            cov_train = PredictTest._cov_empirical(fx_train_centered)
            cov_test = PredictTest._cov_empirical(fx_test_centered)

            return fx_train_mean, fx_test_mean, cov_train, cov_test
Example #8
    @classmethod
    def _get_inputs(cls, out_logits, test_shape, train_shape):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        x_train = np.asarray(normal(train_shape, seed=split))
        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        y_train = np.asarray(
            stateless_uniform(shape=(train_shape[0], out_logits),
                              seed=split,
                              minval=0,
                              maxval=1) < 0.5, np.float32)
        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        x_test = np.asarray(normal(test_shape, seed=split))
        return key, x_test, x_train, y_train
    def testAxes(self, diagonal_axes, trace_axes):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=3)
        key = splits[0]
        self_split = splits[1]
        other_split = splits[2]
        data_self = np.asarray(normal((4, 5, 6, 3), seed=self_split))
        data_other = np.asarray(normal((2, 5, 6, 3), seed=other_split))

        _diagonal_axes = utils.canonicalize_axis(diagonal_axes, data_self)
        _trace_axes = utils.canonicalize_axis(trace_axes, data_self)

        if any(d == c for d in _diagonal_axes for c in _trace_axes):
            raise absltest.SkipTest(
                'diagonal axes must be different from channel axes.')

        implicit, direct, nngp = KERNELS['empirical_logits_3'](
            key, (5, 6, 3),
            CONV,
            diagonal_axes=diagonal_axes,
            trace_axes=trace_axes)

        n_marg = len(_diagonal_axes)
        n_chan = len(_trace_axes)

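        # Trace axes are contracted out of the kernel, and diagonal axes
        # appear once instead of twice, which gives the expected ndim below.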
        g = implicit(data_self, None)
        g_direct = direct(data_self, None)
        g_nngp = nngp(data_self, None)

        self.assertAllClose(g, g_direct)
        self.assertEqual(g_nngp.shape, g.shape)
        self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

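        # Cross-kernels only make sense when the batch axis is neither traced
        # nor diagonal, since x1 and x2 have different batch sizes.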
        if 0 not in _trace_axes and 0 not in _diagonal_axes:
            g = implicit(data_other, data_self)
            g_direct = direct(data_other, data_self)
            g_nngp = nngp(data_other, data_self)

            self.assertAllClose(g, g_direct)
            self.assertEqual(g_nngp.shape, g.shape)
            self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg,
                             g_nngp.ndim)
Example #10
    def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
                   batch_size):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key, 3)
        key = keys[0]
        self_split = keys[1]
        other_split = keys[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))
        kernel_fn = kernel_fn(key, train_shape[1:], network)
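        # `batch._serial` computes the kernel over input batches sequentially
        # and stitches the blocks back into the full matrix.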
        kernel_batched = batch._serial(kernel_fn, batch_size=batch_size)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)
Example #11
    def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
        test_utils.stub_out_pmap(batch, 2)
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key, 3)
        key = keys[0]
        self_split = keys[1]
        other_split = keys[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))

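        # `batch._parallel` shards the kernel computation across devices via
        # pmap (stubbed to two devices above).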
        kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
        kernel_batched = batch._parallel(kernel_fn)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other, True)
Example #12
def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):
    key = stateless_uniform(shape=[2],
                            seed=[1, 1],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    keys = tf_random_split(key)
    key = keys[0]
    split = keys[1]
    x1 = np.asarray(normal((8, 4, 3, 2), seed=key))
    x2 = np.asarray(normal((4, 4, 3, 2), seed=split))

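    # Flatten the spatial dimensions for the fully-connected variant.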
    if not use_conv:
        x1 = np.reshape(x1, (x1.shape[0], -1))
        x2 = np.reshape(x2, (x2.shape[0], -1))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(width, (3, 3)) if use_conv else stax.Dense(width),
        stax.Relu(), stax.Flatten(), stax.Dense(n_classes, 2., 0.5))
    return x1, x2, init_fn, apply_fn, kernel_fn, key
Example #13
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
        training_steps = 1000
        learning_rate = 0.1
        ensemble_size = 1024

        init_fn, apply_fn, kernel_fn = stax.serial(
            stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)
        predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
            kernel_fn,
            x_train,
            y_train,
            learning_rate=learning_rate,
            diag_reg=0.)

        train = (x_train, y_train)
        ensemble_key = tf_random_split(key, ensemble_size)

        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(lambda state, x, y: grad(loss)
                        (get_params(state), x, y))

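        # Train a single ensemble member from its own initialization key.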
        def train_network(key):
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        params = vmap(train_network)(ensemble_key)
        rtol = 0.08

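        # Compare the empirical ensemble mean/covariance against the analytic
        # NTK predictions on both the train and test points.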
        for x in [None, 'x_test']:
            with self.subTest(x=x):
                x = x if x is None else x_test
                x_fin = x_train if x is None else x_test
                ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

                mean_emp = np.mean(ensemble_fx, axis=0)
                mean_subtracted = ensemble_fx - mean_emp
                cov_emp = np.einsum(
                    'ijk,ilk->jl',
                    mean_subtracted,
                    mean_subtracted,
                    optimize=True) / (mean_subtracted.shape[0] *
                                      mean_subtracted.shape[-1])

                ntk = predict_fn_mse_ens(training_steps,
                                         x,
                                         'ntk',
                                         compute_cov=True)
                self._assertAllClose(mean_emp, ntk.mean, rtol)
                self._assertAllClose(cov_emp, ntk.covariance, rtol)
    def testTaylorExpansion(self, shape):
        def f_2_exact(x0, x, params, do_alter, do_shift_x=True):
            w1, w2, b = params
            f_lin = EmpiricalTest.f_lin_exact(x0, x, params, do_alter,
                                              do_shift_x)
            if do_shift_x:
                x0 = x0 * 2 + 1.
                x = x * 2 + 1.
            if do_alter:
                b *= 2.
                w1 += 5.
                w2 /= 0.9
            dx = x - x0
            return f_lin + 0.5 * np.dot(np.dot(dx.T, w1), dx)

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=4)
        key = splits[0]
        s1 = splits[1]
        s2 = splits[2]
        s3 = splits[3]
        w1 = np.asarray(normal(shape, seed=s1))
        w1 = 0.5 * (w1 + w1.T)
        w2 = np.asarray(normal(shape, seed=s2))
        b = np.asarray(normal((shape[-1], ), seed=s3))
        params = (w1, w2, b)

        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=2)
        key = splits[0]
        split = splits[1]
        x0 = np.asarray(normal((shape[-1], ), seed=split))

        f_lin = empirical.taylor_expand(EmpiricalTest.f, x0, 1)
        f_2 = empirical.taylor_expand(EmpiricalTest.f, x0, 2)

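        # Both expansions should match their closed forms at random points.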
        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in [True, False]:
                for do_shift_x in [True, False]:
                    splits = tf_random_split(seed=tf.convert_to_tensor(
                        key, dtype=tf.int32),
                                             num=2)
                    key = splits[0]
                    split = splits[1]
                    x = np.asarray(normal((shape[-1], ), seed=split))
                    self.assertAllClose(
                        EmpiricalTest.f_lin_exact(x0,
                                                  x,
                                                  params,
                                                  do_alter,
                                                  do_shift_x=do_shift_x),
                        f_lin(x, params, do_alter, do_shift_x=do_shift_x))
                    self.assertAllClose(
                        f_2_exact(x0,
                                  x,
                                  params,
                                  do_alter,
                                  do_shift_x=do_shift_x),
                        f_2(x, params, do_alter, do_shift_x=do_shift_x))