Example No. 1
        def predict_mc(count, key):
            key = tf_random_split(key, count)
            fx_train, fx_test = vmap(predict_empirical)(key)
            fx_train_mean = np.mean(fx_train, axis=0)
            fx_test_mean = np.mean(fx_test, axis=0)

            fx_train_centered = fx_train - fx_train_mean
            fx_test_centered = fx_test - fx_test_mean

            cov_train = PredictTest._cov_empirical(fx_train_centered)
            cov_test = PredictTest._cov_empirical(fx_test_centered)

            return fx_train_mean, fx_test_mean, cov_train, cov_test
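
This snippet leaves `PredictTest._cov_empirical` undefined. A plausible stand-in is sketched below; it is an assumption that mirrors the einsum used in Example No. 9 further down, not the library's actual helper.

import numpy as np  # the snippet's `np`; any numpy-compatible module works


def _cov_empirical(fx_centered):
    # fx_centered: mean-subtracted draws of shape (ensemble, n_samples, n_logits).
    # Average the outer product over the ensemble and logit axes to obtain an
    # (n_samples, n_samples) empirical covariance.
    return np.einsum('ijk,ilk->jl', fx_centered, fx_centered,
                     optimize=True) / (
                         fx_centered.shape[0] * fx_centered.shape[-1])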
Example No. 2
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
    datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size)

  # Build the infinite network.
  _, _, kernel_fn = stax.serial(
      stax.Dense(1, 2., 0.05),
      stax.Relu(),
      stax.Dense(1, 2., 0.05)
  )
  # Optionally, compute the kernel in batches, in parallel.
  kernel_fn = nt.batch(kernel_fn,
                       device_count=0,
                       batch_size=FLAGS.batch_size)

  start = time.time()
  # Bayesian and infinite-time gradient descent inference with infinite network.
  predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                        y_train, diag_reg=1e-3)
  fx_test_nngp, fx_test_ntk = predict_fn(x_test=x_test)

  duration = time.time() - start
  print('Kernel construction and inference done in %s seconds.' % duration)

  # Print out accuracy and loss for infinite network predictions.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
  util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
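
Example No. 2 presupposes a surrounding script that defines `FLAGS` and the relevant imports. A minimal sketch of that scaffolding follows; the module paths and flag defaults are assumptions for illustration, not taken from this excerpt.

import time

from absl import app, flags
import jax.numpy as np  # the example's `np`; a numpy-compatible module

import neural_tangents as nt
from neural_tangents import stax

from examples import datasets, util  # helper modules the example assumes

flags.DEFINE_integer('train_size', 1000, 'Number of training examples.')
flags.DEFINE_integer('test_size', 1000, 'Number of test examples.')
flags.DEFINE_integer('batch_size', 25, 'Batch size for kernel computation.')
FLAGS = flags.FLAGS

# ... def main(unused_argv) from Example No. 2 goes here ...

if __name__ == '__main__':
  app.run(main)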
Example No. 3
    def testMaxLearningRate(self, train_shape, network, out_logits,
                            fn_and_kernel):

        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        if len(train_shape) == 2:
            train_shape = (train_shape[0] * 5, train_shape[1] * 10)
        else:
            train_shape = (16, 8, 8, 3)
        x_train = np.asarray(normal(train_shape, seed=split))

        keys = tf_random_split(key)
        key = keys[0]
        split = keys[1]
        y_train = np.asarray(
            stateless_uniform(shape=(train_shape[0], out_logits),
                              seed=split,
                              minval=0,
                              maxval=1) < 0.5, np.float32)
        # Regress to an MSE loss. Note that `f` is only bound inside the loop
        # below, before `grad_loss` is first evaluated.
        loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train)**2)
        grad_loss = jit(grad(loss))

        def get_loss(opt_state):
            return loss(get_params(opt_state), x_train)

        steps = 20

        for lr_factor in [0.5, 3.]:
            params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                           out_logits)
            g_dd = ntk(x_train, None, 'ntk')

            step_size = predict.max_learning_rate(
                g_dd, y_train_size=y_train.size) * lr_factor
            opt_init, opt_update, get_params = optimizers.sgd(step_size)
            opt_state = opt_init(params)

            init_loss = get_loss(opt_state)

            for i in range(steps):
                params = get_params(opt_state)
                opt_state = opt_update(i, grad_loss(params, x_train),
                                       opt_state)

            trained_loss = get_loss(opt_state)
            loss_ratio = trained_loss / (init_loss + 1e-12)
            if lr_factor == 3.:
                if not math.isnan(loss_ratio):
                    self.assertGreater(loss_ratio, 10.)
            else:
                self.assertLess(loss_ratio, 0.1)
Example No. 4
def print_summary(name, labels, net_p, lin_p, loss):
    """Print summary information comparing a network with its linearization."""
    print('\nEvaluating Network on {} data.'.format(name))
    print('---------------------------------------')
    print('Network Accuracy = {}'.format(_accuracy(net_p, labels)))
    print('Network Loss = {}'.format(loss(net_p, labels)))
    if lin_p is not None:
        print('Linearization Accuracy = {}'.format(_accuracy(lin_p, labels)))
        print('Linearization Loss = {}'.format(loss(lin_p, labels)))
        print('RMSE of predictions: {}'.format(
            np.sqrt(np.mean((net_p - lin_p)**2))))
    print('---------------------------------------')
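
A tiny illustrative call of `print_summary`, assuming plain NumPy and that the `_accuracy` helper from Example No. 11 is in scope; the data below is hypothetical.

import numpy as np

labels = np.eye(3)[[0, 1, 2]]          # one-hot labels for three points
net_p = np.array([[0.8, 0.1, 0.1],
                  [0.2, 0.7, 0.1],
                  [0.5, 0.3, 0.2]])     # network predictions
lin_p = net_p + 0.05                    # stand-in for linearized predictions
mse = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)

print_summary('toy', labels, net_p, lin_p, mse)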
Example No. 5
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, apply_fn, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(),
                                       stax.Dense(10, 1., 0.05))

    key = random.stateless_random_uniform(shape=[2],
                                          seed=[0, 0],
                                          minval=None,
                                          maxval=None,
                                          dtype=np.int32)
    _, params = init_fn(key, (1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test, g_td)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
Example No. 6
        def run_test(*args):
            num_samples = 1000
            # Use a high tolerance so the number of samples can stay low;
            # otherwise the test takes a long time to run.
            tol = 0.1
            np_random.seed(10)
            outputs = [np_random.randn(*args) for _ in range(num_samples)]

            # Test output shape.
            for output in outputs:
                self.assertEqual(output.shape, tuple(args))
                default_dtype = (np.float64 if np_dtypes.is_allow_float64()
                                 else np.float32)
                self.assertEqual(output.dtype.as_numpy_dtype, default_dtype)

            if np.prod(args):  # Don't bother with empty arrays.
                outputs = [output.tolist() for output in outputs]

                # Test that the properties of normal distribution are satisfied.
                mean = np.mean(outputs, axis=0)
                stddev = np.std(outputs, axis=0)
                self.assertAllClose(mean, np.zeros(args), atol=tol)
                self.assertAllClose(stddev, np.ones(args), atol=tol)

                # Test that outputs are different with different seeds.
                np_random.seed(20)
                diff_seed_outputs = [
                    np_random.randn(*args).tolist() for _ in range(num_samples)
                ]
                self.assertNotAllClose(outputs, diff_seed_outputs)

                # Test that outputs are the same with the same seed.
                np_random.seed(10)
                same_seed_outputs = [
                    np_random.randn(*args).tolist() for _ in range(num_samples)
                ]
                self.assertAllClose(outputs, same_seed_outputs)
Example No. 7
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            permute_train=True)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(512, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = random.stateless_random_uniform(shape=[2],
                                          seed=[0, 0],
                                          minval=None,
                                          maxval=None,
                                          dtype=np.int32)
    _, params = init_fn(key, (1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{}\t{}'.format(epoch, loss(f(params, x), y),
                                      loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
Example No. 8
    def testPredictND(self):
        n_chan = 6
        key = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        im_shape = (5, 4, 3)
        n_train = 2
        n_test = 2
        x_train = np.asarray(normal((n_train, ) + im_shape, seed=key))
        y_train = stateless_uniform(shape=(n_train, 3, 2, n_chan), seed=key)
        init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
        _, params = init_fn(key, x_train.shape)
        fx_train_0 = apply_fn(params, x_train)

        for trace_axes in [(), (-1, ), (-2, ), (-3, ), (0, 1), (2, 3), (2, ),
                           (1, 3), (0, -1), (0, 0, -3), (0, 1, 2, 3),
                           (0, 1, -1, 2)]:
            for ts in [None, np.arange(6).reshape((2, 3))]:
                for x in [None, 'x_test']:
                    with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
                        t_shape = ts.shape if ts is not None else ()
                        y_test_shape = t_shape + (n_test, ) + y_train.shape[1:]
                        y_train_shape = t_shape + y_train.shape
                        x = x if x is None else np.asarray(
                            normal((n_test, ) + im_shape, seed=key))
                        fx_test_0 = None if x is None else apply_fn(params, x)

                        kernel_fn = empirical.empirical_kernel_fn(
                            apply_fn, trace_axes=trace_axes)

                        # TODO(romann): investigate the SIGTERM error on CPU.
                        # kernel_fn = jit(kernel_fn, static_argnums=(2,))
                        ntk_train_train = kernel_fn(x_train, None, 'ntk',
                                                    params)
                        if x is not None:
                            ntk_test_train = kernel_fn(x, x_train, 'ntk',
                                                       params)

                        loss = lambda x, y: 0.5 * np.mean((x - y)**2)
                        predict_fn_mse = predict.gradient_descent_mse(
                            ntk_train_train, y_train, trace_axes=trace_axes)

                        predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble(
                            kernel_fn,
                            x_train,
                            y_train,
                            trace_axes=trace_axes,
                            params=params)

                        if x is None:
                            p_train_mse = predict_fn_mse(ts, fx_train_0)
                        else:
                            p_train_mse, p_test_mse = predict_fn_mse(
                                ts, fx_train_0, fx_test_0, ntk_test_train)
                            self.assertAllClose(y_test_shape, p_test_mse.shape)
                        self.assertAllClose(y_train_shape, p_train_mse.shape)

                        p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
                            ts, x, ('nngp', 'ntk'), compute_cov=True)
                        ref_shape = y_train_shape if x is None else y_test_shape
                        self.assertAllClose(ref_shape,
                                            p_ntk_mse_ens.mean.shape)
                        self.assertAllClose(ref_shape,
                                            p_nngp_mse_ens.mean.shape)

                        if ts is not None:
                            predict_fn = predict.gradient_descent(
                                loss,
                                ntk_train_train,
                                y_train,
                                trace_axes=trace_axes)

                            if x is None:
                                p_train = predict_fn(ts, fx_train_0)
                            else:
                                p_train, p_test = predict_fn(
                                    ts, fx_train_0, fx_test_0, ntk_test_train)
                                self.assertAllClose(y_test_shape, p_test.shape)
                            self.assertAllClose(y_train_shape, p_train.shape)
Example No. 9
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
        training_steps = 1000
        learning_rate = 0.1
        ensemble_size = 1024

        init_fn, apply_fn, kernel_fn = stax.serial(
            stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)
        predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
            kernel_fn,
            x_train,
            y_train,
            learning_rate=learning_rate,
            diag_reg=0.)

        train = (x_train, y_train)
        ensemble_key = tf_random_split(key, ensemble_size)

        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(
            lambda state, x, y: grad(loss)(get_params(state), x, y))

        def train_network(key):
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        params = vmap(train_network)(ensemble_key)
        rtol = 0.08

        for x in [None, 'x_test']:
            with self.subTest(x=x):
                x = x if x is None else x_test
                x_fin = x_train if x is None else x_test
                ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

                mean_emp = np.mean(ensemble_fx, axis=0)
                mean_subtracted = ensemble_fx - mean_emp
                cov_emp = np.einsum(
                    'ijk,ilk->jl',
                    mean_subtracted,
                    mean_subtracted,
                    optimize=True) / (mean_subtracted.shape[0] *
                                      mean_subtracted.shape[-1])

                ntk = predict_fn_mse_ens(training_steps,
                                         x,
                                         'ntk',
                                         compute_cov=True)
                self._assertAllClose(mean_emp, ntk.mean, rtol)
                self._assertAllClose(cov_emp, ntk.covariance, rtol)
Example No. 10
    def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                            fn_and_kernel, momentum, learning_rate, t, loss):
        key, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)

        params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                       out_logits)

        g_dd = ntk(x_train, None, 'ntk')
        g_td = ntk(x_test, x_train, 'ntk')

        # Regress to an MSE loss.
        loss_fn = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
        grad_loss = jit(grad(lambda params, x: loss_fn(f(params, x), y_train)))

        trace_axes = () if g_dd.ndim == 4 else (-1, )
        if loss == 'mse_analytic':
            if momentum is not None:
                raise absltest.SkipTest(momentum)
            predictor = predict.gradient_descent_mse(
                g_dd,
                y_train,
                learning_rate=learning_rate,
                trace_axes=trace_axes)
        elif loss == 'mse':
            predictor = predict.gradient_descent(loss_fn,
                                                 g_dd,
                                                 y_train,
                                                 learning_rate=learning_rate,
                                                 momentum=momentum,
                                                 trace_axes=trace_axes)
        else:
            raise NotImplementedError(loss)

        predictor = jit(predictor)

        fx_train_0 = f(params, x_train)
        fx_test_0 = f(params, x_test)

        self._test_zero_time(predictor, fx_train_0, fx_test_0, g_td, momentum)
        self._test_multi_step(predictor, fx_train_0, fx_test_0, g_td, momentum)
        if loss == 'mse_analytic':
            self._test_inf_time(predictor, fx_train_0, fx_test_0, g_td,
                                y_train)

        if momentum is None:
            opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        else:
            opt_init, opt_update, get_params = optimizers.momentum(
                learning_rate, momentum)

        opt_state = opt_init(params)
        for i in range(t):
            params = get_params(opt_state)
            opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

        params = get_params(opt_state)

        fx_train_nn, fx_test_nn = f(params, x_train), f(params, x_test)
        fx_train_t, fx_test_t = predictor(t, fx_train_0, fx_test_0, g_td)

        self.assertAllClose(fx_train_nn, fx_train_t, rtol=RTOL, atol=ATOL)
        self.assertAllClose(fx_test_nn, fx_test_t, rtol=RTOL, atol=ATOL)
Example No. 11
def _accuracy(y, y_hat):
    """Compute the accuracy of the predictions with respect to one-hot labels."""
    return np.mean(np.argmax(y, axis=1) == np.argmax(y_hat, axis=1))
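
A quick check on hypothetical data: two of the three argmax predictions below match the one-hot labels, so the result is 2/3.

import numpy as np

y_hat = np.eye(3)[[0, 1, 2]]           # one-hot labels: classes 0, 1, 2
y = np.array([[0.9, 0.05, 0.05],       # argmax 0 -> correct
              [0.1, 0.7, 0.2],         # argmax 1 -> correct
              [0.6, 0.3, 0.1]])        # argmax 0 -> wrong (true class is 2)

print(_accuracy(y, y_hat))             # ~0.667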