Example #1
    def testMaxLearningRate(self, train_shape, network, out_logits,
                            fn_and_kernel):

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        if len(train_shape) == 2:
            train_shape = (train_shape[0] * 5, train_shape[1] * 10)
        else:
            train_shape = (16, 8, 8, 3)
        x_train = np.asarray(normal(train_shape, seed=split))

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
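        # Random binary targets: threshold U(0, 1) at 0.5 and cast to float32.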
        y_train = np.asarray(
            stateless_uniform(shape=(train_shape[0], out_logits),
                              seed=split,
                              minval=0,
                              maxval=1) < 0.5, np.float32)
        # Regress to an MSE loss.
        loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train)**2)
        grad_loss = jit(grad(loss))

        def get_loss(opt_state):
            return loss(get_params(opt_state), x_train)

        steps = 20

        for lr_factor in [0.5, 3.]:
            params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                           out_logits)
            g_dd = ntk(x_train, None, 'ntk')

            step_size = predict.max_learning_rate(
                g_dd, y_train_size=y_train.size) * lr_factor
            opt_init, opt_update, get_params = optimizers.sgd(step_size)
            opt_state = opt_init(params)

            init_loss = get_loss(opt_state)

            for i in range(steps):
                params = get_params(opt_state)
                opt_state = opt_update(i, grad_loss(params, x_train),
                                       opt_state)

            trained_loss = get_loss(opt_state)
            loss_ratio = trained_loss / (init_loss + 1e-12)
            if lr_factor == 3.:
                if not math.isnan(loss_ratio):
                    self.assertGreater(loss_ratio, 10.)
            else:
                self.assertLess(loss_ratio, 0.1)
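
What the test encodes: in the linearized (wide-network) regime, gradient descent on an MSE loss is stable only for step sizes below roughly 2 / λ_max of the NTK, which is what predict.max_learning_rate estimates; hence lr_factor = 0.5 must shrink the loss while 3.0 must blow it up.

The excerpts on this page omit their imports. A minimal preamble they appear to assume, reconstructed from the call signatures alone (an assumption, not part of the original source):

import tensorflow as tf

# Hypothetical reconstructions of the random helpers used throughout the
# examples; the original module may define them differently.
def stateless_uniform(shape, seed, minval=0, maxval=None, dtype=tf.float32):
    # With an integer dtype and minval=None, maxval=None, TF returns
    # full-range integers, which is how the examples derive shape-[2] keys.
    return tf.random.stateless_uniform(shape, seed=seed, minval=minval,
                                       maxval=maxval, dtype=dtype)

def normal(shape, seed):
    return tf.random.stateless_normal(shape, seed=seed)

def tf_random_split(seed, num=2):
    # Deterministically derive `num` new shape-[2] seeds from one seed.
    return tf.random.experimental.stateless_split(seed, num=num)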
Example #2
    def testPredictOnCPU(self):
        key1 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        key2 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        key3 = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
        x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
        x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))

        y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

        _, _, kernel_fn = stax.serial(stax.Conv(1, (3, 3)), stax.Relu(),
                                      stax.Flatten(), stax.Dense(1))

        for store_on_device in [False, True]:
            for device_count in [0, 1]:
                for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
                    for x in [None, 'x_test']:
                        with self.subTest(store_on_device=store_on_device,
                                          device_count=device_count,
                                          get=get,
                                          x=x):
                            kernel_fn_batched = batch.batch(
                                kernel_fn, 2, device_count, store_on_device)
                            predictor = predict.gradient_descent_mse_ensemble(
                                kernel_fn_batched, x_train, y_train)

                            x = x if x is None else x_test
                            predict_none = predictor(None,
                                                     x,
                                                     get,
                                                     compute_cov=True)
                            predict_inf = predictor(np.inf,
                                                    x,
                                                    get,
                                                    compute_cov=True)
                            self.assertAllClose(predict_none, predict_inf)

                            if x is not None:
                                on_cpu = (not store_on_device
                                          or xla_bridge.get_backend().platform
                                          == 'cpu')
                                self.assertEqual(on_cpu,
                                                 utils.is_on_cpu(predict_inf))
                                self.assertEqual(on_cpu,
                                                 utils.is_on_cpu(predict_none))
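
Two conventions the test relies on: batch.batch(kernel_fn, batch_size, device_count, store_on_device) evaluates the kernel in batch_size-sized tiles (device_count=0 meaning plain jit on a single device, store_on_device=False accumulating tiles in host memory), and a time argument of None or np.inf both denote the fully trained, infinite-time limit, which is why predict_none and predict_inf must agree. A minimal sketch of that pattern, reusing the names defined in the test above (shapes hypothetical):

kernel_fn_batched = batch.batch(kernel_fn, 2, device_count=0,
                                store_on_device=False)
predictor = predict.gradient_descent_mse_ensemble(kernel_fn_batched,
                                                  x_train, y_train)
# t=None and t=np.inf request the same infinite-time posterior moments.
mean, cov = predictor(np.inf, x_test, 'ntk', compute_cov=True)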
Example #3
 def init_fun(rng, input_shape):
   output_shape = input_shape[:-1] + (out_dim,)
   keys = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2)
   k1 = keys[0]
   k2 = keys[1]
   # convert the two keys from shape (2,) into a scalar
   k1 = stateless_uniform(shape=[], seed=k1, minval=None, maxval=None, dtype=tf.int32)
   k2 = stateless_uniform(shape=[], seed=k2, minval=None, maxval=None, dtype=tf.int32)
   W = W_init(seed=k1, shape=(input_shape[-1], out_dim))
   b = b_init(seed=k2, shape=(out_dim,))
   return tfnp.zeros(output_shape), (W.numpy(), b.numpy())
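
For context, init_fun reads like the inner half of a stax-style Dense layer factory. A self-contained sketch of the enclosing scope it appears to assume (the initializers and the tfnp import are guesses, not shown in the original):

import tensorflow as tf
import tensorflow.experimental.numpy as tfnp  # one plausible source of tfnp

def Dense(out_dim):
    def W_init(seed, shape):
        # Stateless initializer keyed by a scalar int32 seed (assumption).
        return tf.random.stateless_normal(shape, seed=tf.stack([seed, seed]))
    b_init = W_init

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        keys = tf.random.experimental.stateless_split(
            tf.convert_to_tensor(rng, dtype=tf.int32), num=2)
        # Convert the two shape-(2,) keys into scalar seeds.
        k1 = tf.random.stateless_uniform([], seed=keys[0], minval=None,
                                         maxval=None, dtype=tf.int32)
        k2 = tf.random.stateless_uniform([], seed=keys[1], minval=None,
                                         maxval=None, dtype=tf.int32)
        W = W_init(seed=k1, shape=(input_shape[-1], out_dim))
        b = b_init(seed=k2, shape=(out_dim,))
        return tfnp.zeros(output_shape), (W.numpy(), b.numpy())
    return init_fun

init_fun = Dense(128)
_, (W, b) = init_fun(rng=[0, 42], input_shape=(1, 784))  # W: (784, 128)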
Example #4
    def testAutomatic(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size):
        test_utils.stub_out_pmap(batch, 2)

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key, 3)
        key = keys[0]
        self_split = keys[1]
        other_split = keys[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))

        kernel_fn = kernel_fn(key, train_shape[1:], network)

        kernel_batched = batch.batch(kernel_fn, batch_size=batch_size)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)

        kernel_batched = batch.batch(kernel_fn,
                                     batch_size=batch_size,
                                     store_on_device=False)
        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)
Example #5

    def testNTKAgainstDirect(self, train_shape, test_shape, network, name,
                             kernel_fn):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=3)
        key = splits[0]
        self_split = splits[1]
        other_split = splits[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))

        implicit, direct, _ = kernel_fn(key,
                                        train_shape[1:],
                                        network,
                                        diagonal_axes=(),
                                        trace_axes=())

        g = implicit(data_self, None)
        g_direct = direct(data_self, None)
        self.assertAllClose(g, g_direct)

        g = implicit(data_other, data_self)
        g_direct = direct(data_other, data_self)
        self.assertAllClose(g, g_direct)
Example #6
    def _test_analytic_kernel_composition(self, batching_fn):
        # Check Fully-Connected.
        rng = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(rng)
        rng_self = keys[0]
        rng_other = keys[1]
        x_self = np.asarray(normal((8, 10), seed=rng_self))
        x_other = np.asarray(normal((2, 10), seed=rng_other))
        Block = stax.serial(stax.Dense(256), stax.Relu())

        _, _, ker_fn = Block
        ker_fn = batching_fn(ker_fn)

        _, _, composed_ker_fn = stax.serial(Block, Block)

        ker_out = ker_fn(ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            # In the parallel setting, `x1_is_x2` is not computed correctly
            # when x1==x2.
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        ker_out = ker_fn(ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)

        # Check convolutional + pooling.
        x_self = np.asarray(normal((8, 10, 10, 3), seed=rng))
        x_other = np.asarray(normal((2, 10, 10, 3), seed=rng))

        Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu())
        Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10))

        block_ker_fn, readout_ker_fn = Block[2], Readout[2]
        _, _, composed_ker_fn = stax.serial(Block, Readout)
        block_ker_fn = batching_fn(block_ker_fn)
        readout_ker_fn = batching_fn(readout_ker_fn)

        ker_out = readout_ker_fn(block_ker_fn(x_self))
        composed_ker_out = composed_ker_fn(x_self)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
        ker_out = readout_ker_fn(block_ker_fn(x_self, x_other))
        composed_ker_out = composed_ker_fn(x_self, x_other)
        if batching_fn == batch._parallel:
            composed_ker_out = composed_ker_out.replace(
                x1_is_x2=ker_out.x1_is_x2)
        self.assertAllClose(ker_out, composed_ker_out)
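
What the two halves of this example verify: analytic kernel maps compose, so the kernel of stax.serial(Block, Block) is Block's kernel map applied twice (K_serial(A, B) = K_B ∘ K_A), and batching must commute with that composition; the x1_is_x2 replacement works around the bookkeeping quirk of the parallel path noted in the comment.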
Example #7

def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, apply_fn, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(),
                                       stax.Dense(10, 1., 0.05))

    key = stateless_uniform(shape=[2],
                            seed=[0, 0],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    _, params = init_fn(key, (1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
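
The bookkeeping behind train_steps = int(FLAGS.train_time // FLAGS.learning_rate): n gradient-descent steps of size lr approximate the gradient flow at continuous time t = n * lr, so the discretely trained network is compared against the analytic prediction at t = FLAGS.train_time. On the training set the MSE flow has the closed form (up to the library's normalization conventions)

    f_t = y + exp(-t * Θ) (f_0 - y),

with Θ the train-train NTK g_dd, which is what nt.predict.gradient_descent_mse evaluates.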
Example #8
 def _get_inputs(cls, out_logits, test_shape, train_shape):
     key = stateless_uniform(shape=[2],
                             seed=[0, 0],
                             minval=None,
                             maxval=None,
                             dtype=tf.int32)
     keys = tf_random_split(key)
     key = keys[0]
     split = keys[1]
     x_train = np.asarray(normal(train_shape, seed=split))
     keys = tf_random_split(key)
     key = keys[0]
     split = keys[1]
     y_train = np.asarray(
         stateless_uniform(shape=(train_shape[0], out_logits),
                           seed=split,
                           minval=0,
                           maxval=1) < 0.5, np.float32)
     keys = tf_random_split(key)
     key = keys[0]
     split = keys[1]
     x_test = np.asarray(normal(test_shape, seed=split))
     return key, x_test, x_train, y_train
Example #9

    def testAxes(self, diagonal_axes, trace_axes):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=3)
        key = splits[0]
        self_split = splits[1]
        other_split = splits[2]
        data_self = np.asarray(normal((4, 5, 6, 3), seed=self_split))
        data_other = np.asarray(normal((2, 5, 6, 3), seed=other_split))

        _diagonal_axes = utils.canonicalize_axis(diagonal_axes, data_self)
        _trace_axes = utils.canonicalize_axis(trace_axes, data_self)

        if any(d == c for d in _diagonal_axes for c in _trace_axes):
            raise absltest.SkipTest(
                'diagonal axes must be different from channel axes.')

        implicit, direct, nngp = KERNELS['empirical_logits_3'](
            key, (5, 6, 3),
            CONV,
            diagonal_axes=diagonal_axes,
            trace_axes=trace_axes)

        n_marg = len(_diagonal_axes)
        n_chan = len(_trace_axes)

        g = implicit(data_self, None)
        g_direct = direct(data_self, None)
        g_nngp = nngp(data_self, None)

        self.assertAllClose(g, g_direct)
        self.assertEqual(g_nngp.shape, g.shape)
        self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

        if 0 not in _trace_axes and 0 not in _diagonal_axes:
            g = implicit(data_other, data_self)
            g_direct = direct(data_other, data_self)
            g_nngp = nngp(data_other, data_self)

            self.assertAllClose(g, g_direct)
            self.assertEqual(g_nngp.shape, g.shape)
            self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg,
                             g_nngp.ndim)
Example #10
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            res = np.abs(np.matmul(x1, x2))
            if do_square:
                res *= res
            if do_flip:
                res = -res

            res *= stateless_uniform(shape=[], seed=keys) * p
            return [res, params]
Example #11
    def testGradientDescentMseEnsembleTrain(self):
        key = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        x = np.asarray(normal((8, 4, 6, 3), seed=key))
        _, _, kernel_fn = stax.serial(stax.Conv(1, (2, 2)), stax.Relu(),
                                      stax.Conv(1, (2, 1)))
        y = np.asarray(normal((8, 2, 5, 1), seed=key))
        predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y)

        for t in [None, np.array([0., 1., 10.])]:
            with self.subTest(t=t):
                y_none = predictor(t, None, None, compute_cov=True)
                y_x = predictor(t, x, None, compute_cov=True)
                self._assertAllClose(y_none, y_x, 0.04)
Example #12
 def apply_fun(params, inputs, **kwargs):
   rng = kwargs.get('rng', None)
   if rng is None:
     msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
            "argument. That is, instead of `apply_fun(params, inputs)`, call "
            "it like `apply_fun(params, inputs, rng)` where `rng` is a "
            "jax.random.PRNGKey value.")
     raise ValueError(msg)
   if mode == 'train':
     # keep = random.bernoulli(rng, rate, inputs.shape)
     # bernoulli = tfp.distributions.Bernoulli(probs=rate)
     # keep = bernoulli.sample(sample_shape=inputs.shape, seed=rng[0])
     prob = tf.ones(inputs.shape) * rate
     keep = stateless_uniform(shape=inputs.shape, seed=rng, minval=0, maxval=1) < prob
     return tfnp.where(keep, inputs / rate, 0)
   else:
     return inputs
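
A quick self-contained check of the inverted-dropout scaling above; note that in this snippet rate is the keep probability, so dividing the kept units by rate preserves the expectation, E[keep * x / rate] = x:

import tensorflow as tf

x = tf.ones((10000, 8))
rate = 0.8  # keep probability, matching the snippet's use of `rate`
keep = tf.random.stateless_uniform(x.shape, seed=[3, 7]) < rate
out = tf.where(keep, x / rate, 0.)
print(float(tf.reduce_mean(out)))  # ~1.0: the mean activation is preserved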
Example #13
 def get_samples(x1: np.ndarray, x2: Optional[np.ndarray], get: Get,
                 **apply_fn_kwargs):
     _key = stateless_uniform(shape=[2],
                              seed=key,
                              minval=None,
                              maxval=None,
                              dtype=tf.int32)
     ker_sampled = None
     for n in range(1, max(n_samples) + 1):
         _key, split = tf_split(_key)
         one_sample = kernel_fn_sample_once(x1, x2, split, get,
                                            **apply_fn_kwargs)
         if ker_sampled is None:
             ker_sampled = one_sample
         else:
             ker_sampled = tree_multimap(operator.add, ker_sampled,
                                         one_sample)
         yield n, ker_sampled
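
Note the protocol: the n-th yield carries the running sum of the first n samples (accumulated with tree_multimap), so a caller recovers the Monte Carlo mean by dividing by n. A toy scalar analogue of the same pattern (hypothetical, for illustration):

def running_sums(samples):
    # The n-th yield is the sum of the first n samples, as in get_samples.
    total = None
    for n, s in enumerate(samples, 1):
        total = s if total is None else total + s
        yield n, total

means = [t / n for n, t in running_sums([1., 2., 3.])]  # [1.0, 1.5, 2.0]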
Example #14
    def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
                   batch_size):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key, 3)
        key = keys[0]
        self_split = keys[1]
        other_split = keys[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))
        kernel_fn = kernel_fn(key, train_shape[1:], network)
        kernel_batched = batch._serial(kernel_fn, batch_size=batch_size)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)
Example #15

    def testLinearization(self, shape):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=4)
        key = splits[0]
        s1 = splits[1]
        s2 = splits[2]
        s3 = splits[3]
        w1 = np.asarray(normal(shape, seed=s1))
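        # Symmetrize w1; it plays the role of a Hessian in the exact expansions.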
        w1 = 0.5 * (w1 + w1.T)
        w2 = np.asarray(normal(shape, seed=s2))
        b = np.asarray(normal((shape[-1], ), seed=s3))
        params = (w1, w2, b)

        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=2)
        key = splits[0]
        split = splits[1]
        x0 = np.asarray(normal((shape[-1], ), seed=split))

        f_lin = empirical.linearize(EmpiricalTest.f, x0)

        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in [True, False]:
                for do_shift_x in [True, False]:
                    splits = tf_random_split(seed=tf.convert_to_tensor(
                        key, dtype=tf.int32),
                                             num=2)
                    key = splits[0]
                    split = splits[1]
                    x = np.asarray(normal((shape[-1], ), seed=split))
                    self.assertAllClose(
                        EmpiricalTest.f_lin_exact(x0,
                                                  x,
                                                  params,
                                                  do_alter,
                                                  do_shift_x=do_shift_x),
                        f_lin(x, params, do_alter, do_shift_x=do_shift_x))
Example #16
    def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
        test_utils.stub_out_pmap(batch, 2)
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key, 3)
        key = keys[0]
        self_split = keys[1]
        other_split = keys[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))

        kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
        kernel_batched = batch._parallel(kernel_fn)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other, True)
Example #17
def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):
    key = stateless_uniform(shape=[2],
                            seed=[1, 1],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    keys = tf_random_split(key)
    key = keys[0]
    split = keys[1]
    x1 = np.asarray(normal((8, 4, 3, 2), seed=key))
    x2 = np.asarray(normal((4, 4, 3, 2), seed=split))

    if not use_conv:
        x1 = np.reshape(x1, (x1.shape[0], -1))
        x2 = np.reshape(x2, (x2.shape[0], -1))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(width, (3, 3)) if use_conv else stax.Dense(width),
        stax.Relu(), stax.Flatten(), stax.Dense(n_classes, 2., 0.5))
    return x1, x2, init_fn, apply_fn, kernel_fn, key
Example #18
    def test_jit_or_pmap_broadcast(self):
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            res = np.abs(np.matmul(x1, x2))
            if do_square:
                res *= res
            if do_flip:
                res = -res

            res *= stateless_uniform(shape=[], seed=keys) * p
            return [res, params]

        params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
        x2 = np.arange(0, 10).reshape((10, ))
        keys = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=0)
        x1 = np.arange(0, 10).reshape((1, 10))
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=0):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=True,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=True)
                    self.assertAllClose(res_1, res_2)

        test_utils.stub_out_pmap(batch, 1)
        x1 = np.arange(0, 10).reshape((1, 10))
        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=1)
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=1):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=False,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None)
                    self.assertAllClose(res_1[0], res_2[0])
                    self.assertAllClose(
                        tree_map(partial(np.expand_dims, axis=0), res_1[1]),
                        res_2[1])

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=2)
        x1 = np.arange(0, 20).reshape((2, 10))
        test_utils.stub_out_pmap(batch, 2)

        def broadcast(arg):
            return np.broadcast_to(arg, (2, ) + arg.shape)

        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=2):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      p=0.2)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None,
                                              p=0.2)
                    self.assertAllClose(res_1[0][0], res_2[0][0])
                    self.assertAllClose(res_1[0][1], res_2[0][1])
                    self.assertAllClose(tree_map(broadcast, res_1[1]),
                                        res_2[1])
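
What the assertions encode about the wrapper's contract: with device_count=0, _jit_or_pmap_broadcast is plain jit, so outputs match exactly; with device_count >= 1 it pmaps over the leading axis of x1 and broadcasts the remaining arguments, so unbatched outputs such as params come back with an extra leading device axis, hence the np.expand_dims and broadcast comparisons.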
Example #19
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            permute_train=True)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = stateless_uniform(shape=[2],
                            seed=[0, 0],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    _, params = init_fn(key, (1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = momentum(FLAGS.learning_rate, 0.9)
    # opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)

    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # momentum = MomentumOptimizer(learning_rate=FLAGS.learning_rate, momentum=0.9)
    # momentum_lin = MomentumOptimizer(learning_rate=FLAGS.learning_rate, momentum=0.9)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        # x = np.asarray(x)
        # y = np.asarray(y)

        # momentum.apply_gradients((grad_loss(params, x, y), params))
        # momentum.apply_gradients((grad_loss_lin(params_lin, x, y), params_lin))

        if i % steps_per_epoch == 0:
            print('{}\t{}\t{}'.format(epoch, loss(f(params, x), y),
                                      loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
Example #20
    def testGpInference(self):
        reg = 1e-5
        key = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        x_train = np.asarray(normal((4, 2), seed=key))
        init_fn, apply_fn, kernel_fn_analytic = stax.serial(
            stax.Dense(32, 2., 0.5), stax.Relu(), stax.Dense(10, 2., 0.5))
        y_train = np.asarray(normal((4, 10), seed=key))
        for kernel_fn_is_analytic in [True, False]:
            if kernel_fn_is_analytic:
                kernel_fn = kernel_fn_analytic
            else:
                _, params = init_fn(key, x_train.shape)
                kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

                def kernel_fn(x1, x2, get):
                    return kernel_fn_empirical(x1, x2, get, params)

            for get in [
                    None, 'nngp', 'ntk', ('nngp', ), ('ntk', ),
                ('nngp', 'ntk'), ('ntk', 'nngp')
            ]:
                k_dd = kernel_fn(x_train, None, get)

                gp_inference = predict.gp_inference(k_dd,
                                                    y_train,
                                                    diag_reg=reg)
                gd_ensemble = predict.gradient_descent_mse_ensemble(
                    kernel_fn, x_train, y_train, diag_reg=reg)
                for x_test in [None, 'x_test']:
                    x_test = None if x_test is None else np.asarray(
                        normal((8, 2), seed=key))
                    k_td = None if x_test is None else kernel_fn(
                        x_test, x_train, get)

                    for compute_cov in [True, False]:
                        with self.subTest(
                                kernel_fn_is_analytic=kernel_fn_is_analytic,
                                get=get,
                                x_test=x_test if x_test is None else 'x_test',
                                compute_cov=compute_cov):
                            if compute_cov:
                                nngp_tt = (True if x_test is None else
                                           kernel_fn(x_test, None, 'nngp'))
                            else:
                                nngp_tt = None

                            out_ens = gd_ensemble(None, x_test, get,
                                                  compute_cov)
                            out_ens_inf = gd_ensemble(np.inf, x_test, get,
                                                      compute_cov)
                            self._assertAllClose(out_ens_inf, out_ens, 0.08)

                            if (get is not None and 'nngp' not in get
                                    and compute_cov and k_td is not None):
                                with self.assertRaises(ValueError):
                                    out_gp_inf = gp_inference(
                                        get=get,
                                        k_test_train=k_td,
                                        nngp_test_test=nngp_tt)
                            else:
                                out_gp_inf = gp_inference(
                                    get=get,
                                    k_test_train=k_td,
                                    nngp_test_test=nngp_tt)
                                self.assertAllClose(out_ens, out_gp_inf)
Example #21

    def testTaylorExpansion(self, shape):
        def f_2_exact(x0, x, params, do_alter, do_shift_x=True):
            w1, w2, b = params
            f_lin = EmpiricalTest.f_lin_exact(x0, x, params, do_alter,
                                              do_shift_x)
            if do_shift_x:
                x0 = x0 * 2 + 1.
                x = x * 2 + 1.
            if do_alter:
                b *= 2.
                w1 += 5.
                w2 /= 0.9
            dx = x - x0
            return f_lin + 0.5 * np.dot(np.dot(dx.T, w1), dx)

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=4)
        key = splits[0]
        s1 = splits[1]
        s2 = splits[2]
        s3 = splits[3]
        w1 = np.asarray(normal(shape, seed=s1))
        w1 = 0.5 * (w1 + w1.T)
        w2 = np.asarray(normal(shape, seed=s2))
        b = np.asarray(normal((shape[-1], ), seed=s3))
        params = (w1, w2, b)

        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=2)
        key = splits[0]
        split = splits[1]
        x0 = np.asarray(normal((shape[-1], ), seed=split))

        f_lin = empirical.taylor_expand(EmpiricalTest.f, x0, 1)
        f_2 = empirical.taylor_expand(EmpiricalTest.f, x0, 2)

        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in [True, False]:
                for do_shift_x in [True, False]:
                    splits = tf_random_split(seed=tf.convert_to_tensor(
                        key, dtype=tf.int32),
                                             num=2)
                    key = splits[0]
                    split = splits[1]
                    x = np.asarray(normal((shape[-1], ), seed=split))
                    self.assertAllClose(
                        EmpiricalTest.f_lin_exact(x0,
                                                  x,
                                                  params,
                                                  do_alter,
                                                  do_shift_x=do_shift_x),
                        f_lin(x, params, do_alter, do_shift_x=do_shift_x))
                    self.assertAllClose(
                        f_2_exact(x0,
                                  x,
                                  params,
                                  do_alter,
                                  do_shift_x=do_shift_x),
                        f_2(x, params, do_alter, do_shift_x=do_shift_x))
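
In the notation of f_2_exact, with dx = x - x0 taken after the optional shift and parameter alteration, the reference values are the standard Taylor expansions

    f_lin(x) = f(x0) + ∇f(x0) · dx
    f_2(x)   = f_lin(x) + (1/2) dxᵀ ∇²f(x0) dx,

and the quadratic correction added above, 0.5 * np.dot(np.dot(dx.T, w1), dx), identifies w1 as the Hessian of EmpiricalTest.f.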
Example #22
    def testPredictND(self):
        n_chan = 6
        key = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        im_shape = (5, 4, 3)
        n_train = 2
        n_test = 2
        x_train = np.asarray(normal((n_train, ) + im_shape, seed=key))
        y_train = stateless_uniform(shape=(n_train, 3, 2, n_chan), seed=key)
        init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
        _, params = init_fn(key, x_train.shape)
        fx_train_0 = apply_fn(params, x_train)

        for trace_axes in [(), (-1, ), (-2, ), (-3, ), (0, 1), (2, 3), (2, ),
                           (1, 3), (0, -1), (0, 0, -3), (0, 1, 2, 3),
                           (0, 1, -1, 2)]:
            for ts in [None, np.arange(6).reshape((2, 3))]:
                for x in [None, 'x_test']:
                    with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
                        t_shape = ts.shape if ts is not None else ()
                        y_test_shape = t_shape + (n_test, ) + y_train.shape[1:]
                        y_train_shape = t_shape + y_train.shape
                        x = x if x is None else np.asarray(
                            normal((n_test, ) + im_shape, seed=key))
                        fx_test_0 = None if x is None else apply_fn(params, x)

                        kernel_fn = empirical.empirical_kernel_fn(
                            apply_fn, trace_axes=trace_axes)

                        # TODO(romann): investigate the SIGTERM error on CPU.
                        # kernel_fn = jit(kernel_fn, static_argnums=(2,))
                        ntk_train_train = kernel_fn(x_train, None, 'ntk',
                                                    params)
                        if x is not None:
                            ntk_test_train = kernel_fn(x, x_train, 'ntk',
                                                       params)

                        loss = lambda x, y: 0.5 * np.mean(x - y)**2
                        predict_fn_mse = predict.gradient_descent_mse(
                            ntk_train_train, y_train, trace_axes=trace_axes)

                        predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble(
                            kernel_fn,
                            x_train,
                            y_train,
                            trace_axes=trace_axes,
                            params=params)

                        if x is None:
                            p_train_mse = predict_fn_mse(ts, fx_train_0)
                        else:
                            p_train_mse, p_test_mse = predict_fn_mse(
                                ts, fx_train_0, fx_test_0, ntk_test_train)
                            self.assertAllClose(y_test_shape, p_test_mse.shape)
                        self.assertAllClose(y_train_shape, p_train_mse.shape)

                        p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
                            ts, x, ('nngp', 'ntk'), compute_cov=True)
                        ref_shape = y_train_shape if x is None else y_test_shape
                        self.assertAllClose(ref_shape,
                                            p_ntk_mse_ens.mean.shape)
                        self.assertAllClose(ref_shape,
                                            p_nngp_mse_ens.mean.shape)

                        if ts is not None:
                            predict_fn = predict.gradient_descent(
                                loss,
                                ntk_train_train,
                                y_train,
                                trace_axes=trace_axes)

                            if x is None:
                                p_train = predict_fn(ts, fx_train_0)
                            else:
                                p_train, p_test = predict_fn(
                                    ts, fx_train_0, fx_test_0, ntk_test_train)
                                self.assertAllClose(y_test_shape, p_test.shape)
                            self.assertAllClose(y_train_shape, p_train.shape)