Example #1
    def testGpInference(self):
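        # Check that `predict.gp_inference` agrees with
        # `predict.gradient_descent_mse_ensemble` at infinite training time,
        # for both the analytic and the empirical kernel.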
        reg = 1e-5
        key = random.PRNGKey(1)
        x_train = random.normal(key, (4, 2))
        init_fn, apply_fn, kernel_fn_analytic = stax.serial(
            stax.Dense(32, 2., 0.5), stax.Relu(), stax.Dense(10, 2., 0.5))
        y_train = random.normal(key, (4, 10))
        for kernel_fn_is_analytic in [True, False]:
            if kernel_fn_is_analytic:
                kernel_fn = kernel_fn_analytic
            else:
                _, params = init_fn(key, x_train.shape)
                kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

                def kernel_fn(x1, x2, get):
                    return kernel_fn_empirical(x1, x2, get, params)

            for get in [
                    None, 'nngp', 'ntk', ('nngp', ), ('ntk', ),
                ('nngp', 'ntk'), ('ntk', 'nngp')
            ]:
                k_dd = kernel_fn(x_train, None, get)

                gp_inference = predict.gp_inference(k_dd,
                                                    y_train,
                                                    diag_reg=reg)
                gd_ensemble = predict.gradient_descent_mse_ensemble(
                    kernel_fn, x_train, y_train, diag_reg=reg)
                for x_test in [None, 'x_test']:
                    x_test = None if x_test is None else random.normal(
                        key, (8, 2))
                    k_td = None if x_test is None else kernel_fn(
                        x_test, x_train, get)

                    for compute_cov in [True, False]:
                        with self.subTest(
                                kernel_fn_is_analytic=kernel_fn_is_analytic,
                                get=get,
                                x_test=x_test if x_test is None else 'x_test',
                                compute_cov=compute_cov):
                            if compute_cov:
                                nngp_tt = (True if x_test is None else
                                           kernel_fn(x_test, None, 'nngp'))
                            else:
                                nngp_tt = None

                            out_ens = gd_ensemble(None, x_test, get,
                                                  compute_cov)
                            out_ens_inf = gd_ensemble(np.inf, x_test, get,
                                                      compute_cov)
                            self._assertAllClose(out_ens_inf, out_ens, 0.08)

                            if (get is not None and 'nngp' not in get
                                    and compute_cov and k_td is not None):
                                with self.assertRaises(ValueError):
                                    out_gp_inf = gp_inference(
                                        get=get,
                                        k_test_train=k_td,
                                        nngp_test_test=nngp_tt)
                            else:
                                out_gp_inf = gp_inference(
                                    get=get,
                                    k_test_train=k_td,
                                    nngp_test_test=nngp_tt)
                                self.assertAllClose(out_ens, out_gp_inf)
Example #2
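# Builds an (init_fn, apply_fn, kernel_fn) triple for a parallel readout whose
# three heads output N_out, N_out + 1 and N_out + 2 features respectively.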
def net(N_out):
    return stax.parallel(
        stax.Dense(N_out),
        stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))
Example #3
def main(*args, use_dummy_data: bool = False, **kwargs) -> None:
    # Mask all padding with this value.
    mask_constant = 100.

    if use_dummy_data:
        x_train, y_train, x_test, y_test = _get_dummy_data(mask_constant)
    else:
        # Build data pipelines.
        print('Loading IMDb data.')
        x_train, y_train, x_test, y_test = datasets.get_dataset(
            name='imdb_reviews',
            n_train=_TRAIN_SIZE,
            n_test=_TEST_SIZE,
            do_flatten_and_normalize=False,
            data_dir=_IMDB_PATH,
            input_key='text')

        # Embed words and pad / truncate sentences to a fixed size.
        x_train, x_test = datasets.embed_glove(
            xs=[x_train, x_test],
            glove_path=_GLOVE_PATH,
            max_sentence_length=_MAX_SENTENCE_LENGTH,
            mask_constant=mask_constant)

    # Build the infinite network.
    # Not using the finite model, hence width is set to 1 everywhere.
    _, _, kernel_fn = stax.serial(
        stax.Conv(out_chan=1,
                  filter_shape=(9, ),
                  strides=(1, ),
                  padding='VALID'), stax.Relu(),
        stax.GlobalSelfAttention(n_chan_out=1,
                                 n_chan_key=1,
                                 n_chan_val=1,
                                 pos_emb_type='SUM',
                                 W_pos_emb_std=1.,
                                 pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
                                 n_heads=1), stax.Relu(), stax.GlobalAvgPool(),
        stax.Dense(out_dim=1))

    # Optionally, compute the kernel in batches, in parallel.
    kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=_BATCH_SIZE)

    start = time.time()
    # Bayesian and infinite-time gradient descent inference with infinite network.
    predict = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn=kernel_fn,
        x_train=x_train,
        y_train=y_train,
        diag_reg=1e-6,
        mask_constant=mask_constant)

    fx_test_nngp, fx_test_ntk = predict(x_test=x_test, get=('nngp', 'ntk'))

    fx_test_nngp.block_until_ready()
    fx_test_ntk.block_until_ready()

    duration = time.time() - start
    print(f'Kernel construction and inference done in {duration} seconds.')

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
Example #4
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
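        # Train an ensemble of finite-width networks with SGD and check that
        # the empirical mean and covariance of their test outputs match the
        # analytic NTK ensemble predictions at the corresponding training time.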
        training_steps = 1000
        learning_rate = 0.1
        ensemble_size = 1024

        init_fn, apply_fn, kernel_fn = stax.serial(
            stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key, x_test, x_train, y_train = self._get_inputs(
            out_logits, test_shape, train_shape)
        predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
            kernel_fn,
            x_train,
            y_train,
            learning_rate=learning_rate,
            diag_reg=0.)

        train = (x_train, y_train)
        ensemble_key = random.split(key, ensemble_size)

        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(lambda state, x, y: grad(loss)
                        (get_params(state), x, y))

        def train_network(key):
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        params = vmap(train_network)(ensemble_key)
        rtol = 0.08

        for x in [None, 'x_test']:
            with self.subTest(x=x):
                x = x if x is None else x_test
                x_fin = x_train if x is None else x_test
                ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

                mean_emp = np.mean(ensemble_fx, axis=0, keepdims=True)
                mean_subtracted = ensemble_fx - mean_emp
                cov_emp = np.einsum(
                    'ijk,ilk->jl',
                    mean_subtracted,
                    mean_subtracted,
                    optimize=True) / (mean_subtracted.shape[0] *
                                      mean_subtracted.shape[-1])

                ntk = predict_fn_mse_ens(training_steps,
                                         x,
                                         'ntk',
                                         compute_cov=True)
                self._assertAllClose(mean_emp, ntk.mean, rtol)
                self._assertAllClose(cov_emp, ntk.covariance, rtol)
Example #5
    def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
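        # Compare the analytic kernel of a branched (`FanOut` followed by
        # `FanInSum`/`FanInConcat`) fully-connected network against a Monte
        # Carlo estimate obtained from finite-width networks.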
        if axis in (None, 0) and branch_in == 'dense_after_branch_in':
            raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                               'require `is_gaussian`.')

        if axis == 1 and branch_in == 'dense_before_branch_in':
            raise jtu.SkipTest(
                '`FanInConcat` on feature axis requires a dense layer '
                'after concatenation.')

        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (10, 20))
        X0_2 = None if same_inputs else random.normal(key, (8, 20))

        if xla_bridge.get_backend().platform == 'tpu':
            width = 2048
            n_samples = 1024
            tol = 0.02
        else:
            width = 1024
            n_samples = 256
            tol = 0.01

        dense = stax.Dense(width, 1.25, 0.1)
        input_layers = [dense, stax.FanOut(n_branches)]

        branches = []
        for b in range(n_branches):
            branch_layers = [FanInTest._get_phi(b)]
            for i in range(b):
                branch_layers += [
                    stax.Dense(width, 1. + 2 * i, 0.5 + i),
                    FanInTest._get_phi(i)
                ]

            if branch_in == 'dense_before_branch_in':
                branch_layers += [dense]
            branches += [stax.serial(*branch_layers)]

        output_layers = [
            stax.FanInSum() if axis is None else stax.FanInConcat(axis),
            stax.Relu()
        ]
        if branch_in == 'dense_after_branch_in':
            output_layers.insert(1, dense)

        nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                           output_layers))

        if get == 'nngp':
            init_fn, apply_fn, kernel_fn = nn
        elif get == 'ntk':
            init_fn, apply_fn, kernel_fn = stax.serial(
                nn, stax.Dense(1, 1.25, 0.5))
        else:
            raise ValueError(get)

        kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(init_fn,
                                                         apply_fn,
                                                         key,
                                                         n_samples,
                                                         device_count=0)

        exact = kernel_fn(X0_1, X0_2, get=get)
        empirical = kernel_fn_mc(X0_1, X0_2, get=get)
        empirical = empirical.reshape(exact.shape)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #6
    def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                         readout):
        if xla_bridge.get_backend().platform == 'cpu':
            raise jtu.SkipTest('Not running CNNs on CPU to save time.')

        if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
            raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                               'require `is_gaussian`.')

        if axis == 3 and branch_in == 'dense_before_branch_in':
            raise jtu.SkipTest(
                '`FanInConcat` on feature axis requires a dense layer '
                'after concatenation.')

        key = random.PRNGKey(1)
        X0_1 = random.normal(key, (2, 5, 6, 3))
        X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

        if xla_bridge.get_backend().platform == 'tpu':
            width = 2048
            n_samples = 1024
            tol = 0.02
        else:
            width = 1024
            n_samples = 512
            tol = 0.01

        conv = stax.Conv(out_chan=width,
                         filter_shape=(3, 3),
                         padding='SAME',
                         W_std=1.25,
                         b_std=0.1)

        input_layers = [conv, stax.FanOut(n_branches)]

        branches = []
        for b in range(n_branches):
            branch_layers = [FanInTest._get_phi(b)]
            for i in range(b):
                branch_layers += [
                    stax.Conv(out_chan=width,
                              filter_shape=(i + 1, 4 - i),
                              padding='SAME',
                              W_std=1.25 + i,
                              b_std=0.1 + i),
                    FanInTest._get_phi(i)
                ]

            if branch_in == 'dense_before_branch_in':
                branch_layers += [conv]
            branches += [stax.serial(*branch_layers)]

        output_layers = [
            stax.FanInSum() if axis is None else stax.FanInConcat(axis),
            stax.Relu(),
            stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
        ]
        if branch_in == 'dense_after_branch_in':
            output_layers.insert(1, conv)

        nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                           output_layers))

        init_fn, apply_fn, kernel_fn = stax.serial(
            nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

        kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
            init_fn,
            apply_fn,
            key,
            n_samples,
            device_count=0 if axis in (0, -4) else -1)

        exact = kernel_fn(X0_1, X0_2, get=get)
        empirical = kernel_fn_mc(X0_1, X0_2, get=get)
        empirical = empirical.reshape(exact.shape)
        test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #7
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
        if xla_bridge.get_backend().platform == 'gpu' and config.read(
                'jax_enable_x64'):
            raise jtu.SkipTest('Not running GPU x64 to save time.')
        training_steps = 5000
        learning_rate = 1.0
        ensemble_size = 50

        init_fn, apply_fn, ker_fn = stax.serial(
            stax.Dense(1024, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key = random.PRNGKey(0)
        key, = random.split(key, 1)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)
        train = (x_train, y_train)
        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))

        ensemble_key = random.split(key, ensemble_size)

        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(lambda state, x, y: grad(loss)
                        (get_params(state), x, y))

        def train_network(key):
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        params = vmap(train_network)(ensemble_key)

        ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)
        ensemble_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
        ensemble_loss = np.mean(ensemble_loss)
        self.assertLess(ensemble_loss, 1e-5)

        mean_emp = np.mean(ensemble_fx, axis=0)
        mean_subtracted = ensemble_fx - mean_emp
        cov_emp = np.einsum(
            'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
                mean_subtracted.shape[0] * mean_subtracted.shape[-1])

        reg = 1e-7
        ntk_predictions = predict.gp_inference(ker_fn,
                                               x_train,
                                               y_train,
                                               x_test,
                                               'ntk',
                                               reg,
                                               compute_cov=True)

        self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
        self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL,
                            ATOL)
Example #8
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
      datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # x_train
    import numpy
    # numpy.argmax(y_train,1)%2
    # y_train_tmp = numpy.zeros((y_train.shape[0],2))
    # y_train_tmp[np.arange(y_train.shape[0]),numpy.argmax(y_train,1)%2] = 1
    # y_train = y_train_tmp
    # y_test_tmp = numpy.zeros((y_test.shape[0],2))
    # y_test_tmp[np.arange(y_train.shape[0]),numpy.argmax(y_test,1)%2] = 1
    # y_test = y_test_tmp

    # Collapse the 10-class one-hot labels into a single binary (digit-parity) target.
    y_train_tmp = numpy.argmax(y_train, 1) % 2
    y_train = np.expand_dims(y_train_tmp, 1)
    y_test_tmp = numpy.argmax(y_test, 1) % 2
    y_test = np.expand_dims(y_test_tmp, 1)
    # print(y_train)
    # Build the network
    # init_fn, apply_fn, _ = stax.serial(
    #   stax.Dense(2048, 1., 0.05),
    #   # stax.Erf(),
    #   stax.Relu(),
    #   stax.Dense(2048, 1., 0.05),
    #   # stax.Erf(),
    #   stax.Relu(),
    #   stax.Dense(10, 1., 0.05))
    init_fn, apply_fn, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                       stax.Dense(1, 1., 0.05))

    # key = random.PRNGKey(0)
    # Draw a random 32-bit seed.
    randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                   high=np.iinfo(np.int32).max,
                                   size=2)[0]
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # params

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)
    # state

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)
    # g_dd.shape

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
Example #9
def weight_space(train_embedding, test_embedding, data_set):
    init_fn, f, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        # 2 output units, one per class.
        stax.Dense(2, 1., 0.05))

    key = random.PRNGKey(0)
    # Input shape (-1, 135): 135 is the feature length (9 * 15).
    _, params = init_fn(key, (-1, 135))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(1.0, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    # Mini-batch size.
    batch_size = 64
    train_epochs = 10
    steps_per_epoch = 100

    for i, (x, y) in enumerate(
            datasets.mini_batch(train_embedding, data_set['Y_train'],
                                batch_size, train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1
        if i / steps_per_epoch == train_epochs:
            break

    # Print out summary data comparing the linear / nonlinear model.
    x, y = train_embedding[:10000], data_set['Y_train'][:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                       f_lin(params_lin, test_embedding), loss)
Example #10
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                          permute_train=True)

  # Build the network
  init_fn, f, _ = stax.serial(
      stax.Dense(2048, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(FLAGS.learning_rate,
                                                        0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(datasets.minibatch(
      x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)):

    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(
          epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary(
      'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
Example #11
def make_networks(
    spec,
    actor_hidden_layer_sizes=(256, 256),
    critic_hidden_layer_sizes=(256, 256),
    init_type='glorot_except_dist',
    critic_init_scale=1.0,
    use_double_q=True,
    img_encoder_fn=None,
    build_kernel_fn=False,
    ensemble_method='deep_ensembles',
    ensemble_size=None,  # this is not used for deep ensembles
    mimo_using_obs_tile=False,
    mimo_using_act_tile=False,
):
    """Creates networks used by the agent."""
    assert not (build_kernel_fn and (img_encoder_fn is not None))
    if ensemble_method not in [
            'deep_ensembles', 'mimo', 'tree_deep_ensembles',
            'efficient_tree_deep_ensembles'
    ]:
        raise NotImplementedError()

    num_dimensions = np.prod(spec.actions.shape, dtype=int)

    if init_type == 'glorot_except_dist':
        w_init = hk.initializers.VarianceScaling(1.0, "fan_avg",
                                                 "truncated_normal")
        b_init = jnp.zeros
        dist_w_init = hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform')
        dist_b_init = jnp.zeros
    elif init_type == 'glorot_also_dist':
        w_init = hk.initializers.VarianceScaling(1.0, "fan_avg",
                                                 "truncated_normal")
        b_init = jnp.zeros
        dist_w_init = hk.initializers.VarianceScaling(1.0, "fan_avg",
                                                      "truncated_normal")
        dist_b_init = jnp.zeros
    elif init_type == 'he_normal':
        w_init = hk.initializers.VarianceScaling(2.0, "fan_in",
                                                 "truncated_normal")
        b_init = jnp.zeros
        dist_w_init = w_init
        dist_b_init = b_init
    elif init_type == 'Ilya':
        assert False, 'This is not correct'
        relu_orthogonal = hk.initializers.Orthogonal(scale=2.0**0.5)
        near_zero_orthogonal = hk.initializers.Orthogonal(1e-2)
        w_init = relu_orthogonal
        b_init = jnp.zeros
        dist_w_init = near_zero_orthogonal
        dist_b_init = jnp.zeros
    else:
        raise NotImplementedError

    NUM_MIXTURE_COMPONENTS = 5  # if using gaussian mixtures
    rlu_uniform_initializer = hk.initializers.VarianceScaling(
        distribution='uniform', mode='fan_out', scale=0.333)

    # rlu_uniform_initializer = hk.initializers.VarianceScaling(scale=1e-4)
    def _actor_fn(obs):
        # # for matching Ilya's codebase
        # relu_orthogonal = hk.initializers.Orthogonal(scale=2.0**0.5)
        # near_zero_orthogonal = hk.initializers.Orthogonal(1e-2)
        # x = obs
        # for hid_dim in actor_hidden_layer_sizes:
        #   x = hk.Linear(hid_dim, w_init=relu_orthogonal, b_init=jnp.zeros)(x)
        #   x = jax.nn.relu(x)
        # dist = networks_lib.NormalTanhDistribution(
        #     num_dimensions,
        #     w_init=near_zero_orthogonal,
        #     b_init=jnp.zeros)(x)
        # return dist

        # w_init = hk.initializers.VarianceScaling(2.0, 'fan_in', 'uniform')
        # b_init = jnp.zeros

        # PAPER VERSION
        network = hk.Sequential([
            hk.nets.MLP(
                list(actor_hidden_layer_sizes),
                # w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
                # w_init=hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal"),
                w_init=w_init,
                b_init=b_init,
                activation=jax.nn.relu,
                # activation=jax.nn.tanh,
                activate_final=True),
            networks_lib.NormalTanhDistribution(
                num_dimensions,
                w_init=dist_w_init,
                b_init=dist_b_init,
                min_scale=1e-2,
            ),
            # networks_lib.MultivariateNormalDiagHead(
            #     num_dimensions,
            #     w_init=w_init,
            #     b_init=b_init),
            # networks_lib.GaussianMixture(
            #     num_dimensions,
            #     num_components=5,
            #     multivariate=True),
            # hk.Linear(
            #     NUM_MIXTURE_COMPONENTS + 2 * NUM_MIXTURE_COMPONENTS * num_dimensions,
            #     with_bias=True,
            #     w_init=dist_w_init,
            #     b_init=dist_b_init,),
        ])
        return network(obs)


#   def _actor_fn(obs):
#     # inspired by the ones used in RL Unplugged
#     x = obs
#     x = hk.Sequential([
#         hk.Linear(300, w_init=rlu_uniform_initializer),
#         hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
#         jax.lax.tanh,])(x)
#     x = hk.Linear(1024, w_init=rlu_uniform_initializer)(x)
#     for i in range(4):
#       x = network_utils.ResidualLayerNormBlock(
#           [1024, 1024],
#           activation=jax.nn.relu,
#           w_init=rlu_uniform_initializer,)(x)

#     # a = hk.Linear(
#     #     NUM_MIXTURE_COMPONENTS + 2 * NUM_MIXTURE_COMPONENTS * num_dimensions,
#     #     with_bias=True,
#     #     w_init=hk.initializers.VarianceScaling(scale=1e-5, mode='fan_in'),)(x)
#     a = networks_lib.NormalTanhDistribution(
#         num_dimensions,
#         w_init=dist_w_init,
#         b_init=dist_b_init,
#         min_scale=1e-2,)(x)
#     # a = networks_lib.MultivariateNormalDiagHead(
#     #     num_dimensions,
#     #     min_scale=1e-2,
#     #     w_init=dist_w_init,
#     #     b_init=dist_b_init,)(x)
#     return a

    critic_output_dim = 1
    if ensemble_method in [
            'mimo', 'tree_deep_ensembles', 'efficient_tree_deep_ensembles'
    ]:
        critic_output_dim = ensemble_size

    def small_critic(x):
        # i.e. what people typically use for d4rl benchmark
        _mlp = hk.nets.MLP(
            list(critic_hidden_layer_sizes),
            # w_init=hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal"),
            # w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            # w_init=hk.initializers.VarianceScaling(critic_init_scale, "fan_avg", "truncated_normal"),
            w_init=w_init,
            b_init=b_init,
            activation=jax.nn.relu,
            # activation=jax.nn.tanh,
            activate_final=True)
        h = _mlp(x)
        _linear = hk.Linear(critic_output_dim, w_init=w_init, b_init=b_init)
        v = _linear(h)
        return v, h

    # def small_critic(x):
    #   # this one is for exploring maximal parameterization
    #   width = 256
    #   x = hk.Linear(
    #       width,
    #       w_init=hk.initializers.VarianceScaling(scale=1.0, mode='fan_out', distribution='truncated_normal'),
    #       b_init=hk.initializers.VarianceScaling(scale=0.05, mode='fan_out', distribution='truncated_normal'))(x)
    #   x = x * (float(width) ** 0.5)
    #   x = jax.nn.relu(x)
    #   x = hk.Linear(
    #       width,
    #       w_init=hk.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal'),
    #       b_init=hk.initializers.VarianceScaling(scale=0.05, mode='fan_in', distribution='truncated_normal'),)(x)
    #   x = jax.nn.relu(x)
    #   x = hk.Linear(
    #       width,
    #       w_init=hk.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal'),
    #       b_init=hk.initializers.VarianceScaling(scale=0.05, mode='fan_in', distribution='truncated_normal'),)(x)
    #   x = jax.nn.relu(x)
    #   h = x
    #   x = hk.Linear(
    #       1,
    #       w_init=hk.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal'),
    #       b_init=hk.initializers.VarianceScaling(scale=0.05, mode='fan_in', distribution='truncated_normal'),)(x)
    #   x = x / (float(width) ** 0.5)
    #   return x, h

    # def large_critic(x):
    #   # inspired by the ones used in RL Unplugged, but smaller hidden layer sizes
    #   hid_dim = 256
    #   _encoder = hk.Linear(hid_dim, w_init=w_init, b_init=b_init)
    #   x = _encoder(x)
    #   for i in range(4):
    #     x = network_utils.ResidualLayerNormBlock(
    #         [hid_dim, hid_dim],
    #         activation=jax.nn.relu,
    #         w_init=w_init,
    #         b_init=b_init,)(x)
    #   h = hk.Linear(hid_dim, w_init=w_init, b_init=b_init)(x)
    #   v = hk.Linear(critic_output_dim, w_init=w_init, b_init=b_init)(h)
    #   return v, h
    def large_critic(x):
        # inspired by the ones used in RL Unplugged
        x = hk.Sequential([
            hk.Linear(400, w_init=rlu_uniform_initializer),
            hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
            jax.lax.tanh,
        ])(x)
        x = hk.Linear(1024, w_init=rlu_uniform_initializer)(x)
        for i in range(4):
            x = network_utils.ResidualLayerNormBlock(
                [1024, 1024],
                activation=jax.nn.relu,
                w_init=rlu_uniform_initializer,
            )(x)
        h = x
        # v = hk.Linear(1, w_init=rlu_uniform_initializer)(h)
        # v = hk.Linear(critic_output_dim)(h)
        all_vs = []
        for _ in range(critic_output_dim):
            head_v = hk.Linear(256, w_init=rlu_uniform_initializer)(h)
            head_v = jax.nn.relu(head_v)
            head_v = hk.Linear(1, w_init=rlu_uniform_initializer)(head_v)
            all_vs.append(head_v)
        v = jnp.concatenate(all_vs, axis=-1)
        return v, h

    # def _critic_fn(obs, action):
    def _all_critic_stuff(obs, action):
        # for matching Ilya's codebase
        # relu_orthogonal = hk.initializers.Orthogonal(scale=2.0**0.5)
        # near_zero_orthogonal = hk.initializers.Orthogonal(1e-2)
        # def _cn(x):
        #   for hid_dim in critic_hidden_layer_sizes:
        #     x = hk.Linear(hid_dim, w_init=relu_orthogonal, b_init=jnp.zeros)(x)
        #     x = jax.nn.relu(x)
        #   x = hk.Linear(1, w_init=near_zero_orthogonal, b_init=jnp.zeros)(x)
        #   return x
        # input_ = jnp.concatenate([obs, action], axis=-1)
        # if use_double_q:
        #   value1 = _cn(input_)
        #   value2 = _cn(input_)
        #   return jnp.concatenate([value1, value2], axis=-1)
        # else:
        #   return _cn(input_)

        # w_init = hk.initializers.VarianceScaling(2.0, 'fan_in', 'uniform')
        # b_init = jnp.zeros

        #####################################
        input_ = jnp.concatenate([obs, action], axis=-1)

        if ensemble_method == 'tree_deep_ensembles':
            critic_network_builder = network_utils.build_tree_deep_ensemble_critic(
                w_init, b_init, use_double_q)
        elif ensemble_method == 'efficient_tree_deep_ensembles':
            critic_network_builder = network_utils.build_efficient_tree_deep_ensemble_critic(
                w_init, b_init, use_double_q)
        else:
            # for standard d4rl architecture
            critic_network_builder = small_critic
            # for larger architecture inspired by rl unplugged
            # critic_network_builder = large_critic

        value1, h1 = critic_network_builder(input_)
        if ensemble_method in [
                'mimo', 'tree_deep_ensembles', 'efficient_tree_deep_ensembles'
        ]:
            value1 = jnp.reshape(value1, [-1, ensemble_size, 1])

        if use_double_q:
            value2, h2 = critic_network_builder(input_)
            if ensemble_method in [
                    'mimo', 'tree_deep_ensembles',
                    'efficient_tree_deep_ensembles'
            ]:
                value2 = jnp.reshape(value2, [-1, ensemble_size, 1])
            return jnp.concatenate([value1, value2],
                                   axis=-1), jnp.concatenate([h1, h2], axis=-1)
        else:
            return value1, h1

    def get_particular_critic_init(w_init, b_init, key, obs, act):
        def _critic_with_particular_init(obs, action):
            raise NotImplementedError(
                'Not implemented for MIMO, Not implemented for new version that also returns h1, h2'
            )
            network1 = hk.Sequential([
                hk.nets.MLP(list(critic_hidden_layer_sizes) + [1],
                            w_init=w_init,
                            b_init=b_init,
                            activation=jax.nn.relu,
                            activate_final=False),
            ])
            input_ = jnp.concatenate([obs, action], axis=-1)
            value1 = network1(input_)
            if use_double_q:
                network2 = hk.Sequential([
                    hk.nets.MLP(list(critic_hidden_layer_sizes) + [1],
                                w_init=w_init,
                                b_init=b_init,
                                activation=jax.nn.relu,
                                activate_final=False),
                ])
                value2 = network2(input_)
                return jnp.concatenate([value1, value2], axis=-1)
            else:
                return value1

        init_fn = hk.without_apply_rng(
            hk.transform(_critic_with_particular_init, apply_rng=True)).init
        return init_fn(key, obs, act)

    kernel_fn = None
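    # Optionally build an infinite-width kernel function (via neural_tangents'
    # `stax`) mirroring the critic MLP architecture.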
    if build_kernel_fn:
        layers = []
        for hid_dim in critic_hidden_layer_sizes:
            # W_std = 1.5
            W_std = 2.0
            layers += [
                stax.Dense(hid_dim, W_std=W_std, b_std=0.05),
                stax.Relu()
            ]
        layers += [stax.Dense(1, W_std=W_std, b_std=0.05)]
        nt_init_fn, nt_apply_fn, nt_kernel_fn = stax.serial(*layers)
        kernel_fn = jax.jit(nt_kernel_fn, static_argnums=(2, ))

    if img_encoder_fn is not None:
        # _actor_fn = bimanual_sweep.policy_on_encoder_v0(num_dimensions)
        # _critic_fn = bimanual_sweep.critic_on_encoder_v0(use_double_q=use_double_q)
        _actor_fn = bimanual_sweep.policy_on_encoder_v1(num_dimensions)
        raise NotImplementedError(
            'Need to handle the returning of h1, h2 with new version of all_critic_stuff'
        )
        _critic_fn = bimanual_sweep.critic_on_encoder_v1(
            use_double_q=use_double_q)

    def _simclr_encoder(h):
        # return hk.nets.MLP(
        #     [256, 128],
        #     # [256, 256, 256],
        #     w_init=w_init,
        #     # b_init=b_init, # b_init should not be set when not using bias
        #     with_bias=False,
        #     activation=jax.nn.relu,
        #     activate_final=False)(h)

        # IF YOU CHANGE THIS AND USE SASS, YOU NEED TO FIX THE SASS ENCODER OPTIM STEP
        return h  # i.e. no encoder (sometimes referred to as "projection")

    policy = hk.without_apply_rng(hk.transform(_actor_fn, apply_rng=True))
    # critic = hk.without_apply_rng(hk.transform(_critic_fn, apply_rng=True))
    all_critic_stuff = hk.without_apply_rng(
        hk.transform(_all_critic_stuff, apply_rng=True))
    critic_init = all_critic_stuff.init
    critic_apply = lambda p, obs, act: all_critic_stuff.apply(p, obs, act)[0]
    critic_repr = lambda p, obs, act: all_critic_stuff.apply(p, obs, act)[1]

    simclr_encoder = hk.without_apply_rng(
        hk.transform(_simclr_encoder, apply_rng=True))

    # Create dummy observations and actions to create network parameters.
    dummy_action = utils.zeros_like(spec.actions)
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_action = utils.add_batch_dim(dummy_action)
    dummy_obs = utils.add_batch_dim(dummy_obs)
    tile_shape = [1 for _ in range(dummy_action.ndim)]
    tile_shape[0] = 256
    dummy_action = jnp.tile(dummy_action, tile_shape)
    tile_shape = [1 for _ in range(dummy_obs.ndim)]
    tile_shape[0] = 256
    dummy_obs = jnp.tile(dummy_obs, tile_shape)

    if img_encoder_fn is not None:
        img_encoder = hk.without_apply_rng(
            hk.transform(img_encoder_fn, apply_rng=True))
        key = jax.random.PRNGKey(seed=42)
        temp_encoder_params = img_encoder.init(key, dummy_obs['state_image'])
        dummy_hidden = img_encoder.apply(temp_encoder_params,
                                         dummy_obs['state_image'])
        img_encoder_network = networks_lib.FeedForwardNetwork(
            lambda key: img_encoder.init(key, dummy_hidden), img_encoder.apply)
        dummy_encoded_input = dict(
            state_image=dummy_hidden,
            state_dense=dummy_obs['state_dense'],
        )
    else:
        img_encoder_fn = None
        dummy_encoded_input = dummy_obs
        img_encoder_network = None

    critic_dummy_encoded_input = dummy_encoded_input
    critic_dummy_action = dummy_action
    if ensemble_method == 'mimo':
        if mimo_using_obs_tile:
            # if using the version where we are also tiling the obs
            tile_array = [1] * len(
                critic_dummy_encoded_input.shape)  # type: ignore
            tile_array[-1] = ensemble_size
            critic_dummy_encoded_input = jnp.tile(critic_dummy_encoded_input,
                                                  tile_array)

        if mimo_using_act_tile:
            # if using the version where we are also tiling the acts
            tile_array = [1] * len(critic_dummy_action.shape)
            tile_array[-1] = ensemble_size
            critic_dummy_action = jnp.tile(critic_dummy_action, tile_array)

    temp_critic_params = critic_init(jax.random.PRNGKey(42),
                                     critic_dummy_encoded_input,
                                     critic_dummy_action)
    dummy_critic_repr = critic_repr(temp_critic_params,
                                    critic_dummy_encoded_input,
                                    critic_dummy_action)

    # mixture_sample = build_gaussian_mixture_sample(num_dimensions, NUM_MIXTURE_COMPONENTS, eval_mode=False)
    # mixture_sample_eval = build_gaussian_mixture_sample(num_dimensions, NUM_MIXTURE_COMPONENTS, eval_mode=True)
    # mixture_log_prob = build_gaussian_mixture_log_prob(num_dimensions, NUM_MIXTURE_COMPONENTS)

    return MSGNetworks(
        policy_network=networks_lib.FeedForwardNetwork(
            lambda key: policy.init(key, dummy_encoded_input), policy.apply),
        q_network=networks_lib.FeedForwardNetwork(
            lambda key: critic_init(key, critic_dummy_encoded_input,
                                    critic_dummy_action), critic_apply),
        log_prob=lambda params, actions: params.log_prob(actions),
        sample=lambda params, key: params.sample(seed=key),
        # sample_eval=lambda params, key: params.mode(),
        sample_eval=lambda params, key: params.sample(seed=key),

        # log_prob=mixture_log_prob,
        # sample=mixture_sample,
        # # sample_eval=lambda params, key: params.mode(),
        # sample_eval=mixture_sample_eval,
        img_encoder=img_encoder_network,
        kernel_fn=kernel_fn,
        get_particular_critic_init=lambda
        w_init, b_init, key: get_particular_critic_init(
            w_init, b_init, key, dummy_encoded_input, dummy_action),
        get_critic_repr=critic_repr,
        simclr_encoder=networks_lib.FeedForwardNetwork(
            lambda key: simclr_encoder.init(key, dummy_critic_repr),
            simclr_encoder.apply),
    )
Example #12
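# This fragment assumes `model_name`, `W_std`, `b_std`, `X`, `row_id`, `m` and
# `n` are defined earlier in the script; `stax` and `nt` come from
# neural_tangents, `jit` from jax, and `onp` aliases NumPy.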
if model_name == 'Myrtle':
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std,
                  padding='SAME'),
        stax.Relu(),
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std,
                  padding='SAME'),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std,
                  padding='SAME'),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std,
                  padding='SAME'),
        stax.Relu(),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
        stax.Flatten(),
        stax.Dense(10, W_std, b_std))
else:
    raise ValueError(f'Unknown model name: {model_name}')

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2, ))
kernel_fn = nt.batch(kernel_fn, batch_size=20)

X1 = X[row_id * m:(row_id + 1) * m, :, :, :]
assert X1.shape[0] == m and X1.shape[1] == 32 and X1.shape[
    2] == 32 and X1.shape[3] == 3

# Training kernel
K = onp.zeros((m, n), dtype=onp.float32)
col_count = int(n / m)
for col_id in range(row_id, col_count):
Example #13
  def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in,
                     fan_in_mode):
    if fan_in_mode in ['FanInSum', 'FanInProd']:
      if axis != 0:
        raise absltest.SkipTest('`FanInSum` and `FanInProd` are skipped when '
                                'axis != 0.')
      axis = None
    if (fan_in_mode == 'FanInSum' or
        axis == 0) and branch_in == 'dense_after_branch_in':
      raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` '
                              'require `is_gaussian`.')

    if ((axis == 1 or fan_in_mode == 'FanInProd') and
        branch_in == 'dense_before_branch_in'):
      raise absltest.SkipTest(
          '`FanInConcat` or `FanInProd` on feature axis requires a dense layer '
          'after concatenation or Hadamard product.')
    if fan_in_mode == 'FanInSum':
      fan_in_layer = stax.FanInSum()
    elif fan_in_mode == 'FanInProd':
      fan_in_layer = stax.FanInProd()
    else:
      fan_in_layer = stax.FanInConcat(axis)

    if n_branches != 2:
      test_utils.skip_test(self)

    key = random.PRNGKey(1)
    X0_1 = np.cos(random.normal(key, (4, 3)))
    X0_2 = None if same_inputs else random.normal(key, (8, 3))

    width = 1024
    n_samples = 256 * 2

    if default_backend() == 'tpu':
      tol = 0.07
    else:
      tol = 0.02

    dense = stax.Dense(width, 1.25, 0.1)
    input_layers = [dense,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        multiplier = 1 if axis not in (1, -1) else (1 + 0.25 * i)
        branch_layers += [
            stax.Dense(int(width * multiplier), 1. + 2 * i, 0.5 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [dense]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        fan_in_layer,
        stax.Relu()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, dense)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    if get == 'nngp':
      init_fn, apply_fn, kernel_fn = nn
    elif get == 'ntk':
      init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
    else:
      raise ValueError(get)

    kernel_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples,
        device_count=0 if axis in (0, -2) else -1,
        implementation=2,
        vmap_axes=None if axis in (0, -2) else 0,
    )

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #14
  def test_fan_in_conv(self,
                       same_inputs,
                       axis,
                       n_branches,
                       get,
                       branch_in,
                       readout,
                       fan_in_mode):
    test_utils.skip_test(self)
    if fan_in_mode in ['FanInSum', 'FanInProd']:
      if axis != 0:
        raise absltest.SkipTest('`FanInSum` and `FanInProd()` are skipped when '
                                'axis != 0.')
      axis = None
    if (fan_in_mode == 'FanInSum' or
        axis in [0, 1, 2]) and branch_in == 'dense_after_branch_in':
      raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                              'require `is_gaussian`.')

    if ((axis == 3 or fan_in_mode == 'FanInProd') and
        branch_in == 'dense_before_branch_in'):
      raise absltest.SkipTest('`FanInConcat` or `FanInProd` on feature axis '
                              'requires a dense layer after concatenation '
                              'or Hadamard product.')

    if fan_in_mode == 'FanInSum':
      fan_in_layer = stax.FanInSum()
    elif fan_in_mode == 'FanInProd':
      fan_in_layer = stax.FanInProd()
    else:
      fan_in_layer = stax.FanInConcat(axis)

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (2, 5, 6, 3))
    X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

    if default_backend() == 'tpu':
      width = 2048
      n_samples = 1024
      tol = 0.02
    else:
      width = 1024
      n_samples = 512
      tol = 0.01

    conv = stax.Conv(out_chan=width,
                     filter_shape=(3, 3),
                     padding='SAME',
                     W_std=1.25,
                     b_std=0.1)

    input_layers = [conv,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        multiplier = 1 if axis not in (3, -1) else (1 + 0.25 * i)
        branch_layers += [
            stax.Conv(
                out_chan=int(width * multiplier),
                filter_shape=(i + 1, 4 - i),
                padding='SAME',
                W_std=1.25 + i,
                b_std=0.1 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [conv]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        fan_in_layer,
        stax.Relu(),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, conv)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    init_fn, apply_fn, kernel_fn = stax.serial(
        nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

    kernel_fn_mc = nt.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        key,
        n_samples,
        device_count=0 if axis in (0, -4) else -1,
        implementation=2,
        vmap_axes=None if axis in (0, -4) else 0,
    )

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    test_utils.assert_close_matrices(self, empirical, exact, tol)
Example #15
  def test_kwargs(self, do_batch, mode):
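    # Check that layer-specific keyword arguments (here the `pattern` inputs
    # consumed by `stax.Aggregate`) are threaded consistently through kernel
    # computation and the different `predict` APIs.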
    rng = random.PRNGKey(1)

    x_train = random.normal(rng, (8, 7, 10))
    x_test = random.normal(rng, (4, 7, 10))
    y_train = random.normal(rng, (8, 1))

    rng_train, rng_test = random.split(rng, 2)

    pattern_train = random.normal(rng, (8, 7, 7))
    pattern_test = random.normal(rng, (4, 7, 7))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(8),
        stax.Relu(),
        stax.Dropout(rate=0.4),
        stax.Aggregate(),
        stax.GlobalAvgPool(),
        stax.Dense(1)
    )

    kw_dd = dict(pattern=(pattern_train, pattern_train))
    kw_td = dict(pattern=(pattern_test, pattern_train))
    kw_tt = dict(pattern=(pattern_test, pattern_test))

    if mode == 'mc':
      kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                        batch_size=2 if do_batch else 0)

    elif mode == 'empirical':
      kernel_fn = empirical_kernel_fn(apply_fn)
      if do_batch:
        raise absltest.SkipTest('Batching of empirical kernel is not '
                                'implemented with keyword arguments.')

      for kw in (kw_dd, kw_td, kw_tt):
        kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                       get=('nngp', 'ntk')))

      kw_dd.update(dict(rng=(rng_train, None)))
      kw_td.update(dict(rng=(rng_test, rng_train)))
      kw_tt.update(dict(rng=(rng_test, None)))

    elif mode == 'analytic':
      if do_batch:
        kernel_fn = batch.batch(kernel_fn, batch_size=2)

    else:
      raise ValueError(mode)

    k_dd = kernel_fn(x_train, None, **kw_dd)
    k_td = kernel_fn(x_test, x_train, **kw_td)
    k_tt = kernel_fn(x_test, None, **kw_tt)

    # Infinite time NNGP/NTK.
    predict_fn_gp = predict.gp_inference(k_dd, y_train)
    out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

    if mode == 'empirical':
      for kw in (kw_dd, kw_td, kw_tt):
        kw.pop('get')

    predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                                x_train,
                                                                y_train,
                                                                **kw_dd)
    out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True, **kw_tt)
    self.assertAllClose(out_gp, out_ensemble)

    # Finite time NTK test.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
    out_mse = predict_fn_mse(t=1.,
                             fx_train_0=None,
                             fx_test_0=0.,
                             k_test_train=k_td.ntk)
    out_ensemble = predict_fn_ensemble(t=1.,
                                       get='ntk',
                                       x_test=x_test,
                                       compute_cov=False,
                                       **kw_tt)
    self.assertAllClose(out_mse, out_ensemble)

    # Finite time NNGP train.
    predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
    out_mse = predict_fn_mse(t=2.,
                             fx_train_0=0.,
                             fx_test_0=None,
                             k_test_train=k_td.nngp)
    out_ensemble = predict_fn_ensemble(t=2.,
                                       get='nngp',
                                       x_test=None,
                                       compute_cov=False,
                                       **kw_dd)
    self.assertAllClose(out_mse, out_ensemble)