Example 1
def _kernel_fns(key,
                input_shape,
                network,
                out_logits,
                diagonal_axes,
                trace_axes,
                vmap_axes=None):
    # Build the network and initialize its parameters.
    init_fn, f, _ = _build_network(input_shape, network, out_logits)
    _, params = init_fn(key, (-1,) + input_shape)

    # Empirical NTK via implicit NTK-vector products (implementation=2)
    # and via direct Jacobian contraction (implementation=1).
    implicit_kernel_fn = jit(
        nt.empirical_ntk_fn(f,
                            trace_axes,
                            diagonal_axes,
                            vmap_axes,
                            implementation=2))
    direct_kernel_fn = jit(
        nt.empirical_ntk_fn(f,
                            trace_axes,
                            diagonal_axes,
                            vmap_axes,
                            implementation=1))

    # Empirical NNGP kernel.
    nngp_kernel_fn = jit(nt.empirical_nngp_fn(f, trace_axes, diagonal_axes))

    # Bind the initialized parameters so callers only pass inputs.
    return (partial(implicit_kernel_fn, params=params),
            partial(direct_kernel_fn, params=params),
            partial(nngp_kernel_fn, params=params))
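A minimal sketch of how `_kernel_fns` might be called. `_build_network` is not shown in this example, so every argument value below (input shape, network spec, number of logits) is an illustrative assumption:

# Hypothetical usage; argument values are assumptions, since _build_network
# is defined elsewhere.
key = random.PRNGKey(1)
implicit_fn, direct_fn, nngp_fn = _kernel_fns(
    key,
    input_shape=(784,),
    network='FLAT',      # whatever spec _build_network expects (assumed)
    out_logits=10,
    diagonal_axes=(),
    trace_axes=(-1,))
x = random.normal(key, (5, 784))
k_implicit = implicit_fn(x, None)  # (5, 5); params are already bound
k_direct = direct_fn(x, None)      # (5, 5); should match k_implicit
k_nngp = nngp_fn(x, None)          # (5, 5) empirical NNGP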
Example 2
    def test_parallel_nested(self, same_inputs):
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        # Split the 33 features into widths 10, 11 and 12 to build pytree
        # inputs of the form ([leaf, leaf], leaf).
        x1_1, x1_2, x1_3 = np.split(random.normal(input_key1, (3, 33)),
                                    (10, 21),
                                    axis=1)
        x2_1, x2_2, x2_3 = np.split(random.normal(input_key2, (4, 33)),
                                    (10, 21),
                                    axis=1)

        x1 = ([x1_1, x1_2], x1_3)
        x2 = ([x2_1, x2_2], x2_3) if not same_inputs else None

        def layer(N_out):
            return stax.parallel(
                stax.parallel(stax.Dense(N_out), stax.Dense(N_out + 1)),
                stax.Dense(N_out + 2))

        init_fn, apply_fn, _ = stax.serial(layer(1024), layer(1))

        _, params = init_fn(net_key, tree_map(np.shape, x1))
        # NTK via implicit NTK-vector products (implementation=2) and via
        # direct Jacobian contraction (implementation=1).
        implicit_kernel_fn = jit(
            nt.empirical_ntk_fn(apply_fn, implementation=2))
        direct_kernel_fn = jit(nt.empirical_ntk_fn(apply_fn, implementation=1))

        # Batched variants: vmap_axes mirrors the ([x1_1, x1_2], x1_3) input
        # pytree, marking axis 0 of every leaf as the batch axis.
        implicit_batched_kernel_fn = jit(
            nt.empirical_ntk_fn(apply_fn,
                                vmap_axes=([0, 0], 0),
                                implementation=2))
        direct_batched_kernel_fn = jit(
            nt.empirical_ntk_fn(apply_fn,
                                vmap_axes=([0, 0], 0),
                                implementation=1))

        k_direct = direct_kernel_fn(x1, x2, params)

        self.assertAllClose(k_direct, implicit_kernel_fn(x1, x2, params))
        self.assertAllClose(k_direct, direct_batched_kernel_fn(x1, x2, params))
        self.assertAllClose(k_direct,
                            implicit_batched_kernel_fn(x1, x2, params))

        nngp_kernel_fn = jit(nt.empirical_nngp_fn(apply_fn))
        nngp = nngp_kernel_fn(x1, x2, params)

        self.assertEqual(len(nngp), 2)
        nngp_shape = (3, 3 if same_inputs else 4)
        self.assertEqual(nngp[0][0].shape, nngp_shape)
        self.assertEqual(nngp[0][1].shape, nngp_shape)
        self.assertEqual(nngp[1].shape, nngp_shape)
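In the test above, `vmap_axes=([0, 0], 0)` mirrors the `([x1_1, x1_2], x1_3)` input pytree, marking axis 0 of every leaf as the batch axis. A smaller standalone sketch of the same pattern (widths and shapes are illustrative):

from jax import jit, random
import neural_tangents as nt
from neural_tangents import stax

# A two-branch network whose inputs and outputs are 2-tuples.
init_fn, apply_fn, _ = stax.parallel(stax.Dense(8), stax.Dense(8))
x = (random.normal(random.PRNGKey(0), (3, 5)),
     random.normal(random.PRNGKey(1), (3, 5)))
_, params = init_fn(random.PRNGKey(2), ((-1, 5), (-1, 5)))

# vmap_axes=(0, 0) matches the 2-tuple structure of inputs and outputs.
ntk_fn = jit(nt.empirical_ntk_fn(apply_fn, vmap_axes=(0, 0)))
k = ntk_fn(x, None, params)  # a 2-tuple of (3, 3) kernels, one per branch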
Example 3
def network_fns(layers, x_train):
    # Create model and kernel functions for every depth-(i + 1) prefix of
    # the network, so quantities can be inspected layer by layer.
    layer_fns = []
    kernel_fns = []
    emp_kernel_fns = []
    for i in range(len(layers)):
        init_fn, apply_fn, kernel_fn = stax.serial(*layers[:i + 1])
        layer_fns += [jit(apply_fn)]
        kernel_fns += [jit(kernel_fn)]
        # Empirical NNGP of the depth-(i + 1) sub-network, with x_train and
        # x2=None pre-bound; the resulting function takes only params.
        emp_kernel_fns += [
            jit(partial(nt.empirical_nngp_fn(layer_fns[i]), x_train, None))
        ]
    # Functions for the full network.
    init_fn, apply_fn, kernel_fn = stax.serial(*layers)
    apply_fn = jit(apply_fn)
    kernel_fn = jit(kernel_fn)

    return init_fn, apply_fn, kernel_fn, layer_fns, kernel_fns, emp_kernel_fns
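A hedged usage sketch for `network_fns`; the layer stack, input shapes, and the per-prefix parameter slicing are illustrative assumptions:

from jax import random

layers = [stax.Dense(512), stax.Relu(), stax.Dense(10)]
x_train = random.normal(random.PRNGKey(0), (20, 32))

init_fn, apply_fn, kernel_fn, layer_fns, kernel_fns, emp_kernel_fns = \
    network_fns(layers, x_train)
_, params = init_fn(random.PRNGKey(1), (-1, 32))

# x_train and x2=None are pre-bound, so only params is passed. The depth-1
# prefix uses only the first layer's parameters (assuming serial params
# are a per-layer list).
nngp_first = emp_kernel_fns[0](params[:1])  # shape (20, 20)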
Example 4
    def test_parallel_in_out_empirical(self, same_inputs):
        test_utils.stub_out_pmap(batching, 2)
        rng = random.PRNGKey(0)
        input_key1, input_key2, net_key = random.split(rng, 3)

        # Unpack a (3, 4, 1) sample into three (4, 1) inputs.
        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 1))
        x1 = (x1_1, (x1_2, x1_3))

        if same_inputs:
            x2 = None
        else:
            x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 1))
            x2 = (x2_1, (x2_2, x2_3))

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        # Check NNGP.
        init_fn, apply_fn, _ = net(WIDTH)
        _, params = init_fn(net_key, ((-1, 1), ((-1, 1), (-1, 1))))

        kernel_fn = jit(nt.empirical_nngp_fn(apply_fn))
        batch_kernel_fn = jit(batching.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)

        # Check NTK.
        init_fn, apply_fn, _ = stax.serial(net(WIDTH), net(1))
        _, params = init_fn(net_key, ((-1, 1), ((-1, 1), (-1, 1))))

        kernel_fn = jit(nt.empirical_ntk_fn(apply_fn))
        batch_kernel_fn = jit(batching.batch(kernel_fn, 2))

        test_utils.assert_close_matrices(self, kernel_fn(x1, x2, params),
                                         batch_kernel_fn(x1, x2, params), RTOL)
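`batching.batch` (exercised above under a stubbed two-device pmap) computes the empirical kernel in tiles over the leading batch axes and reassembles the full matrix. A minimal sketch of the same idea through the public `nt.batch` wrapper, with illustrative shapes:

from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.Dense(16)
x = random.normal(random.PRNGKey(0), (8, 3))
_, params = init_fn(random.PRNGKey(1), (-1, 3))

kernel_fn = nt.empirical_nngp_fn(apply_fn)
# Compute the (8, 8) kernel in 4x4-example blocks on a single device.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=4, device_count=0)
assert batched_kernel_fn(x, None, params).shape == (8, 8)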
Example 5
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size)

    # Build the network (stax.Erf() can be swapped in for stax.Relu()).
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        stax.Relu(),
        stax.Dense(2048, 1., 0.05),
        stax.Relu(),
        stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an MSE loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    # trace_axes=() keeps the full (examples, examples, logits, logits)
    # NTK, which the flattening below relies on.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, trace_axes=()),
                   batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Flatten the (examples, examples, logits, logits) NTKs into
    # (examples * 10) x (examples * 10) matrices.
    m = FLAGS.train_size
    print(m)
    n = m * 10
    m_test = FLAGS.test_size
    n_test = m_test * 10
    theta = g_dd.transpose((0, 2, 1, 3)).reshape(n, n)
    theta_test = ntk(x_test, None, params).transpose((0, 2, 1, 3)).reshape(
        n_test, n_test)
    theta_tilde = g_td.transpose((0, 2, 1, 3)).reshape(n_test, n)
    # Empirical NNGP kernels, expanded over the 10 output logits with a
    # Kronecker product so their shapes match the flattened NTKs.
    nngp_fn = nt.empirical_nngp_fn(apply_fn)
    K = np.kron(nngp_fn(x_train, None, params), np.eye(10))          # (n, n)
    K_test = np.kron(nngp_fn(x_test, None, params), np.eye(10))      # (n_test, n_test)
    K_tilde = np.kron(nngp_fn(x_test, x_train, params), np.eye(10))  # (n_test, n)

    # Covariance built from the decayed NTK (Sigma is computed but not
    # used below).
    t = FLAGS.train_time  # training horizon; assumed value for the
                          # previously undefined t
    decay_matrix = np.eye(n) - scipy.linalg.expm(-t * theta)
    Sigma = K + np.matmul(
        decay_matrix,
        np.matmul(K, np.matmul(np.linalg.inv(theta),
                               np.matmul(decay_matrix, theta))) - 2 * K)

    # Gaussian-process regression with observation noise sigma_noise:
    # alpha = (sigma^2 I + K)^{-1} y are the dual coefficients.
    sigma_noise = 1.0
    Y = y_train.reshape(n)
    alpha = np.matmul(np.linalg.inv(np.eye(n) * sigma_noise**2 + K), Y)

    # log det(I + K / sigma^2) via the eigenvalues of K, which is more
    # stable than forming the determinant directly.
    eigs = np.linalg.eigh(K)[0]
    logdetcoviK = np.sum(np.log((eigs + sigma_noise**2) / sigma_noise**2))

    # KL divergence between the GP posterior and the prior.
    coviK = np.eye(n) + K / sigma_noise**2
    KL = (0.5 * logdetcoviK + 0.5 * np.trace(np.linalg.inv(coviK))
          + 0.5 * np.matmul(alpha.T, np.matmul(K, alpha)) - n / 2)
    print(KL)
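    # Restated in math (no additional computation): with B = I + K/sigma^2
    # and alpha = (K + sigma^2 I)^{-1} y,
    #   KL = 0.5 * (log det B + trace(B^{-1}) + alpha^T K alpha - n).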

    # Generalization bound computed from the KL term above.
    delta = 2**-10
    bound = (KL + 2 * np.log(m) + 1 - np.log(delta)) / m
    bound = 1 - np.exp(-bound)
    print("bound", bound)

    # Sample initial function values from the joint NTK Gaussian over train
    # and test points. Plain numpy is used for in-place block assignment
    # (jax arrays are immutable) and for the multivariate normal sampler.
    import numpy
    bigK = numpy.zeros((n + n_test, n + n_test))
    bigK[0:n, 0:n] = K
    bigK[0:n, n:] = theta_tilde.T
    bigK[n:, 0:n] = theta_tilde
    bigK[n:, n:] = theta_test
    init_ntk_f = numpy.random.multivariate_normal(np.zeros(n + n_test), bigK)
    fx_train = init_ntk_f[:n].reshape(m, 10)
    fx_test = init_ntk_f[n:].reshape(m_test, 10)

    # Alternatively, take the initial values directly from the network:
    # fx_train = apply_fn(params, x_train)
    # fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from the analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)