def testGpInference(self):
  reg = 1e-5
  key = random.PRNGKey(1)
  x_train = random.normal(key, (4, 2))
  init_fn, apply_fn, kernel_fn_analytic = stax.serial(
      stax.Dense(32, 2., 0.5),
      stax.Relu(),
      stax.Dense(10, 2., 0.5))
  y_train = random.normal(key, (4, 10))

  for kernel_fn_is_analytic in [True, False]:
    if kernel_fn_is_analytic:
      kernel_fn = kernel_fn_analytic
    else:
      _, params = init_fn(key, x_train.shape)
      kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

      def kernel_fn(x1, x2, get):
        return kernel_fn_empirical(x1, x2, get, params)

    for get in [None, 'nngp', 'ntk', ('nngp',), ('ntk',), ('nngp', 'ntk'),
                ('ntk', 'nngp')]:
      k_dd = kernel_fn(x_train, None, get)

      gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
      gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                          x_train,
                                                          y_train,
                                                          diag_reg=reg)

      for x_test in [None, 'x_test']:
        x_test = None if x_test is None else random.normal(key, (8, 2))
        k_td = None if x_test is None else kernel_fn(x_test, x_train, get)

        for compute_cov in [True, False]:
          with self.subTest(
              kernel_fn_is_analytic=kernel_fn_is_analytic,
              get=get,
              x_test=x_test if x_test is None else 'x_test',
              compute_cov=compute_cov):
            if compute_cov:
              nngp_tt = (True if x_test is None else
                         kernel_fn(x_test, None, 'nngp'))
            else:
              nngp_tt = None

            out_ens = gd_ensemble(None, x_test, get, compute_cov)
            out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov)
            self._assertAllClose(out_ens_inf, out_ens, 0.08)

            if (get is not None and
                'nngp' not in get and
                compute_cov and
                k_td is not None):
              with self.assertRaises(ValueError):
                out_gp_inf = gp_inference(get=get,
                                          k_test_train=k_td,
                                          nngp_test_test=nngp_tt)
            else:
              out_gp_inf = gp_inference(get=get,
                                        k_test_train=k_td,
                                        nngp_test_test=nngp_tt)
              self.assertAllClose(out_ens, out_gp_inf)
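
# The test above checks that closed-form GP inference agrees with the
# infinite-time limit of the gradient-descent ensemble predictor. Below is a
# minimal, standalone sketch of that equivalence using only the call patterns
# that appear in the test; widths, shapes, and regularization are illustrative
# assumptions, not values mandated by the test.
import jax.numpy as np
from jax import random
import neural_tangents as nt
from neural_tangents import stax

key = random.PRNGKey(0)
x_train = random.normal(key, (4, 2))
y_train = random.normal(key, (4, 10))
x_test = random.normal(key, (8, 2))

_, _, kernel_fn = stax.serial(
    stax.Dense(32, 2., 0.5), stax.Relu(), stax.Dense(10, 2., 0.5))

# Infinite-time gradient descent ensemble (t=None means t -> infinity).
ensemble = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-5)
mean_ens = ensemble(None, x_test, 'ntk')

# Closed-form GP inference from precomputed train/train and test/train kernels.
k_dd = kernel_fn(x_train, None, 'ntk')
k_td = kernel_fn(x_test, x_train, 'ntk')
gp = nt.predict.gp_inference(k_dd, y_train, diag_reg=1e-5)
mean_gp = gp(get='ntk', k_test_train=k_td)
# mean_ens and mean_gp should agree up to numerical tolerance.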
def net(N_out):
  return stax.parallel(stax.Dense(N_out),
                       stax.parallel(stax.Dense(N_out + 1),
                                     stax.Dense(N_out + 2)))
def main(*args, use_dummy_data: bool = False, **kwargs) -> None:
  # Mask all padding with this value.
  mask_constant = 100.

  if use_dummy_data:
    x_train, y_train, x_test, y_test = _get_dummy_data(mask_constant)
  else:
    # Build data pipelines.
    print('Loading IMDb data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset(
        name='imdb_reviews',
        n_train=_TRAIN_SIZE,
        n_test=_TEST_SIZE,
        do_flatten_and_normalize=False,
        data_dir=_IMDB_PATH,
        input_key='text')

    # Embed words and pad / truncate sentences to a fixed size.
    x_train, x_test = datasets.embed_glove(
        xs=[x_train, x_test],
        glove_path=_GLOVE_PATH,
        max_sentence_length=_MAX_SENTENCE_LENGTH,
        mask_constant=mask_constant)

  # Build the infinite network.
  # Not using the finite model, hence width is set to 1 everywhere.
  _, _, kernel_fn = stax.serial(
      stax.Conv(out_chan=1, filter_shape=(9,), strides=(1,), padding='VALID'),
      stax.Relu(),
      stax.GlobalSelfAttention(
          n_chan_out=1,
          n_chan_key=1,
          n_chan_val=1,
          pos_emb_type='SUM',
          W_pos_emb_std=1.,
          pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
          n_heads=1),
      stax.Relu(),
      stax.GlobalAvgPool(),
      stax.Dense(out_dim=1))

  # Optionally, compute the kernel in batches, in parallel.
  kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=_BATCH_SIZE)

  start = time.time()
  # Bayesian and infinite-time gradient descent inference with infinite network.
  predict = nt.predict.gradient_descent_mse_ensemble(
      kernel_fn=kernel_fn,
      x_train=x_train,
      y_train=y_train,
      diag_reg=1e-6,
      mask_constant=mask_constant)

  fx_test_nngp, fx_test_ntk = predict(x_test=x_test, get=('nngp', 'ntk'))

  fx_test_nngp.block_until_ready()
  fx_test_ntk.block_until_ready()

  duration = time.time() - start
  print(f'Kernel construction and inference done in {duration} seconds.')

  # Print out accuracy and loss for infinite network predictions.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
  util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
  util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
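
# `util.print_summary` is defined elsewhere in the example. Purely as a hedged
# illustration of what such a summary could report, the snippet below computes
# the MSE loss defined above plus an accuracy metric, under the assumption that
# `y_test` is argmax-decodable (e.g. one-hot) sentiment labels.
def accuracy(fx, y_hat):
  # Fraction of test points whose predicted class matches the label.
  return np.mean(np.argmax(fx, axis=-1) == np.argmax(y_hat, axis=-1))

for name, fx in (('NNGP', fx_test_nngp), ('NTK', fx_test_ntk)):
  print(f'{name} test: loss={loss(fx, y_test):.4f}, '
        f'accuracy={accuracy(fx, y_test):.4f}')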
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
  training_steps = 1000
  learning_rate = 0.1
  ensemble_size = 1024

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(128, W_std=1.2, b_std=0.05),
      stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
  opt_update = jit(opt_update)

  key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                   train_shape)
  predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
      kernel_fn,
      x_train,
      y_train,
      learning_rate=learning_rate,
      diag_reg=0.)

  train = (x_train, y_train)
  ensemble_key = random.split(key, ensemble_size)

  loss = jit(lambda params, x, y: 0.5 * np.mean((apply_fn(params, x) - y)**2))
  grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

  def train_network(key):
    _, params = init_fn(key, (-1,) + train_shape[1:])
    opt_state = opt_init(params)
    for i in range(training_steps):
      opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
    return get_params(opt_state)

  params = vmap(train_network)(ensemble_key)
  rtol = 0.08

  for x in [None, 'x_test']:
    with self.subTest(x=x):
      x = x if x is None else x_test
      x_fin = x_train if x is None else x_test

      ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

      mean_emp = np.mean(ensemble_fx, axis=0, keepdims=True)
      mean_subtracted = ensemble_fx - mean_emp
      cov_emp = np.einsum(
          'ijk,ilk->jl', mean_subtracted, mean_subtracted,
          optimize=True) / (mean_subtracted.shape[0] *
                            mean_subtracted.shape[-1])

      ntk = predict_fn_mse_ens(training_steps, x, 'ntk', compute_cov=True)
      self._assertAllClose(mean_emp, ntk.mean, rtol)
      self._assertAllClose(cov_emp, ntk.covariance, rtol)
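
# The einsum above estimates the predictive covariance empirically by averaging
# outer products of mean-centered ensemble outputs over ensemble members and
# logits. A small self-contained sanity check of that identity (illustrative
# only; uses plain NumPy and random data, not the test's arrays):
import numpy as onp

f = onp.random.randn(5, 3, 2)  # (ensemble members, test inputs, logits)
centered = f - f.mean(axis=0, keepdims=True)

cov_einsum = onp.einsum('ijk,ilk->jl', centered, centered) / (
    f.shape[0] * f.shape[-1])

# Same quantity via an explicit loop over members i and logits k.
cov_loop = sum(
    onp.outer(centered[i, :, k], centered[i, :, k])
    for i in range(f.shape[0])
    for k in range(f.shape[-1])) / (f.shape[0] * f.shape[-1])

assert onp.allclose(cov_einsum, cov_loop)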
def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
  if axis in (None, 0) and branch_in == 'dense_after_branch_in':
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                       'require `is_gaussian`.')

  if axis == 1 and branch_in == 'dense_before_branch_in':
    raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                       'after concatenation.')

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (10, 20))
  X0_2 = None if same_inputs else random.normal(key, (8, 20))

  if xla_bridge.get_backend().platform == 'tpu':
    width = 2048
    n_samples = 1024
    tol = 0.02
  else:
    width = 1024
    n_samples = 256
    tol = 0.01

  dense = stax.Dense(width, 1.25, 0.1)
  input_layers = [dense, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      branch_layers += [
          stax.Dense(width, 1. + 2 * i, 0.5 + i),
          FanInTest._get_phi(i)
      ]

    if branch_in == 'dense_before_branch_in':
      branch_layers += [dense]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      stax.FanInSum() if axis is None else stax.FanInConcat(axis),
      stax.Relu()
  ]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, dense)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = nn
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
  else:
    raise ValueError(get)

  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples, device_count=0)

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  empirical = empirical.reshape(exact.shape)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout):
  if xla_bridge.get_backend().platform == 'cpu':
    raise jtu.SkipTest('Not running CNNs on CPU to save time.')

  if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                       'require `is_gaussian`.')

  if axis == 3 and branch_in == 'dense_before_branch_in':
    raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                       'after concatenation.')

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (2, 5, 6, 3))
  X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

  if xla_bridge.get_backend().platform == 'tpu':
    width = 2048
    n_samples = 1024
    tol = 0.02
  else:
    width = 1024
    n_samples = 512
    tol = 0.01

  conv = stax.Conv(out_chan=width,
                   filter_shape=(3, 3),
                   padding='SAME',
                   W_std=1.25,
                   b_std=0.1)

  input_layers = [conv, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      branch_layers += [
          stax.Conv(out_chan=width,
                    filter_shape=(i + 1, 4 - i),
                    padding='SAME',
                    W_std=1.25 + i,
                    b_std=0.1 + i),
          FanInTest._get_phi(i)
      ]

    if branch_in == 'dense_before_branch_in':
      branch_layers += [conv]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      stax.FanInSum() if axis is None else stax.FanInConcat(axis),
      stax.Relu(),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
  ]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, conv)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if axis in (0, -4) else -1)

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  empirical = empirical.reshape(exact.shape)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
  if xla_bridge.get_backend().platform == 'gpu' and config.read(
      'jax_enable_x64'):
    raise jtu.SkipTest('Not running GPU x64 to save time.')

  training_steps = 5000
  learning_rate = 1.0
  ensemble_size = 50

  init_fn, apply_fn, ker_fn = stax.serial(
      stax.Dense(1024, W_std=1.2, b_std=0.05),
      stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
  opt_update = jit(opt_update)

  key = random.PRNGKey(0)
  key, = random.split(key, 1)

  key, split = random.split(key)
  x_train = np.cos(random.normal(split, train_shape))

  key, split = random.split(key)
  y_train = np.array(
      random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)
  train = (x_train, y_train)

  key, split = random.split(key)
  x_test = np.cos(random.normal(split, test_shape))

  ensemble_key = random.split(key, ensemble_size)

  loss = jit(lambda params, x, y: 0.5 * np.mean((apply_fn(params, x) - y)**2))
  grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

  def train_network(key):
    _, params = init_fn(key, (-1,) + train_shape[1:])
    opt_state = opt_init(params)
    for i in range(training_steps):
      opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
    return get_params(opt_state)

  params = vmap(train_network)(ensemble_key)

  ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)
  ensemble_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
  ensemble_loss = np.mean(ensemble_loss)
  self.assertLess(ensemble_loss, 1e-5, True)

  mean_emp = np.mean(ensemble_fx, axis=0)
  mean_subtracted = ensemble_fx - mean_emp
  cov_emp = np.einsum(
      'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
          mean_subtracted.shape[0] * mean_subtracted.shape[-1])

  reg = 1e-7
  ntk_predictions = predict.gp_inference(
      ker_fn, x_train, y_train, x_test, 'ntk', reg, compute_cov=True)

  self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
  self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL, ATOL)
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.mnist(FLAGS.train_size, FLAGS.test_size)

  # x_train
  import numpy
  # numpy.argmax(y_train,1)%2
  # y_train_tmp = numpy.zeros((y_train.shape[0],2))
  # y_train_tmp[np.arange(y_train.shape[0]),numpy.argmax(y_train,1)%2] = 1
  # y_train = y_train_tmp
  # y_test_tmp = numpy.zeros((y_test.shape[0],2))
  # y_test_tmp[np.arange(y_train.shape[0]),numpy.argmax(y_test,1)%2] = 1
  # y_test = y_test_tmp

  y_train_tmp = numpy.argmax(y_train, 1) % 2
  y_train = np.expand_dims(y_train_tmp, 1)
  y_test_tmp = numpy.argmax(y_test, 1) % 2
  y_test = np.expand_dims(y_test_tmp, 1)
  # print(y_train)

  # Build the network.
  # init_fn, apply_fn, _ = stax.serial(
  #     stax.Dense(2048, 1., 0.05),
  #     # stax.Erf(),
  #     stax.Relu(),
  #     stax.Dense(2048, 1., 0.05),
  #     # stax.Erf(),
  #     stax.Relu(),
  #     stax.Dense(10, 1., 0.05))
  init_fn, apply_fn, _ = stax.serial(
      stax.Dense(2048, 1., 0.05),
      stax.Erf(),
      stax.Dense(1, 1., 0.05))

  # key = random.PRNGKey(0)
  randnnn = numpy.random.random_integers(
      np.iinfo(np.int32).min, high=np.iinfo(np.int32).max, size=2)[0]
  key = random.PRNGKey(randnnn)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an MSE loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
  grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
  g_dd = ntk(x_train, None, params)
  g_td = ntk(x_test, x_train, params)
  predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

  # Get initial values of the network in function space.
  fx_train = apply_fn(params, x_train)
  fx_test = apply_fn(params, x_test)

  # Train the network.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                     loss)
  util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
def weight_space(train_embedding, test_embedding, data_set):
  init_fn, f, _ = stax.serial(
      stax.Dense(512, 1., 0.05),
      stax.Erf(),
      # 2 output units, one per class.
      stax.Dense(2, 1., 0.05))

  key = random.PRNGKey(0)
  # (-1, 135): 135 is the feature length, here 9 * 15 = 135.
  _, params = init_fn(key, (-1, 135))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(1.0, 0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  batch_size = 64
  train_epochs = 10
  steps_per_epoch = 100

  for i, (x, y) in enumerate(
      datasets.mini_batch(train_embedding, data_set['Y_train'], batch_size,
                          train_epochs)):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                        loss(f_lin(params_lin, x), y)))
      epoch += 1
      if i / steps_per_epoch == train_epochs:
        break

  # Print out summary data comparing the linear / nonlinear model.
  x, y = train_embedding[:10000], data_set['Y_train'][:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                     f_lin(params_lin, test_embedding), loss)
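
# For reference, `nt.linearize(f, params)` used above returns the first-order
# Taylor expansion of `f` around the initial parameters. A minimal sketch of
# the equivalent computation via a JVP (this is an illustration of what `f_lin`
# computes, not how the example calls the library):
import jax

def linearize_manual(f, params0):
  """First-order Taylor expansion of f around params0."""
  def f_lin(params, x):
    dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params0)
    # f_lin(p, x) = f(p0, x) + J_f(p0, x) @ (p - p0), computed without
    # materializing the Jacobian.
    f0, df = jax.jvp(lambda p: f(p, x), (params0,), (dparams,))
    return f0 + df
  return f_lin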
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.get_dataset(
      'mnist', permute_train=True)

  # Build the network.
  init_fn, f, _ = stax.serial(
      stax.Dense(2048, 1., 0.05),
      stax.Erf(),
      stax.Dense(10, 1., 0.05))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = nt.linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(FLAGS.learning_rate,
                                                        0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(datasets.minibatch(
      x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(
          epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary(
      'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
def make_networks(
    spec,
    actor_hidden_layer_sizes=(256, 256),
    critic_hidden_layer_sizes=(256, 256),
    init_type='glorot_except_dist',
    critic_init_scale=1.0,
    use_double_q=True,
    img_encoder_fn=None,
    build_kernel_fn=False,
    ensemble_method='deep_ensembles',
    ensemble_size=None,  # this is not used for deep ensembles
    mimo_using_obs_tile=False,
    mimo_using_act_tile=False,
):
  """Creates networks used by the agent."""
  assert not (build_kernel_fn and (img_encoder_fn is not None))
  if ensemble_method not in [
      'deep_ensembles', 'mimo', 'tree_deep_ensembles',
      'efficient_tree_deep_ensembles'
  ]:
    raise NotImplementedError()

  num_dimensions = np.prod(spec.actions.shape, dtype=int)

  if init_type == 'glorot_except_dist':
    w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")
    b_init = jnp.zeros
    dist_w_init = hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform')
    dist_b_init = jnp.zeros
  elif init_type == 'glorot_also_dist':
    w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal")
    b_init = jnp.zeros
    dist_w_init = hk.initializers.VarianceScaling(1.0, "fan_avg",
                                                  "truncated_normal")
    dist_b_init = jnp.zeros
  elif init_type == 'he_normal':
    w_init = hk.initializers.VarianceScaling(2.0, "fan_in", "truncated_normal")
    b_init = jnp.zeros
    dist_w_init = w_init
    dist_b_init = b_init
  elif init_type == 'Ilya':
    assert False, 'This is not correct'
    relu_orthogonal = hk.initializers.Orthogonal(scale=2.0**0.5)
    near_zero_orthogonal = hk.initializers.Orthogonal(1e-2)
    w_init = relu_orthogonal
    b_init = jnp.zeros
    dist_w_init = near_zero_orthogonal
    dist_b_init = jnp.zeros
  else:
    raise NotImplementedError

  NUM_MIXTURE_COMPONENTS = 5  # if using gaussian mixtures

  rlu_uniform_initializer = hk.initializers.VarianceScaling(
      distribution='uniform', mode='fan_out', scale=0.333)
  # rlu_uniform_initializer = hk.initializers.VarianceScaling(scale=1e-4)

  def _actor_fn(obs):
    # # for matching Ilya's codebase
    # relu_orthogonal = hk.initializers.Orthogonal(scale=2.0**0.5)
    # near_zero_orthogonal = hk.initializers.Orthogonal(1e-2)
    # x = obs
    # for hid_dim in actor_hidden_layer_sizes:
    #   x = hk.Linear(hid_dim, w_init=relu_orthogonal, b_init=jnp.zeros)(x)
    #   x = jax.nn.relu(x)
    # dist = networks_lib.NormalTanhDistribution(
    #     num_dimensions,
    #     w_init=near_zero_orthogonal,
    #     b_init=jnp.zeros)(x)
    # return dist

    # w_init = hk.initializers.VarianceScaling(2.0, 'fan_in', 'uniform')
    # b_init = jnp.zeros

    # PAPER VERSION
    network = hk.Sequential([
        hk.nets.MLP(
            list(actor_hidden_layer_sizes),
            # w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            # w_init=hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal"),
            w_init=w_init,
            b_init=b_init,
            activation=jax.nn.relu,
            # activation=jax.nn.tanh,
            activate_final=True),
        networks_lib.NormalTanhDistribution(
            num_dimensions,
            w_init=dist_w_init,
            b_init=dist_b_init,
            min_scale=1e-2,
        ),
        # networks_lib.MultivariateNormalDiagHead(
        #     num_dimensions,
        #     w_init=w_init,
        #     b_init=b_init),
        # networks_lib.GaussianMixture(
        #     num_dimensions,
        #     num_components=5,
        #     multivariate=True),
        # hk.Linear(
        #     NUM_MIXTURE_COMPONENTS + 2 * NUM_MIXTURE_COMPONENTS * num_dimensions,
        #     with_bias=True,
        #     w_init=dist_w_init,
        #     b_init=dist_b_init,),
    ])
    return network(obs)

  # def _actor_fn(obs):
  #   # inspired by the ones used in RL Unplugged
  #   x = obs
  #   x = hk.Sequential([
  #       hk.Linear(300, w_init=rlu_uniform_initializer),
  #       hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
  #       jax.lax.tanh,])(x)
  #   x = hk.Linear(1024, w_init=rlu_uniform_initializer)(x)
  #   for i in range(4):
  #     x = network_utils.ResidualLayerNormBlock(
  #         [1024, 1024],
  #         activation=jax.nn.relu,
  #         w_init=rlu_uniform_initializer,)(x)
  #   # a = hk.Linear(
  #   #     NUM_MIXTURE_COMPONENTS + 2 * NUM_MIXTURE_COMPONENTS * num_dimensions,
  #   #     with_bias=True,
  #   #     w_init=hk.initializers.VarianceScaling(scale=1e-5, mode='fan_in'),)(x)
  #   a = networks_lib.NormalTanhDistribution(
  #       num_dimensions,
  #       w_init=dist_w_init,
  #       b_init=dist_b_init,
  #       min_scale=1e-2,)(x)
  #   # a = networks_lib.MultivariateNormalDiagHead(
  #   #     num_dimensions,
  #   #     min_scale=1e-2,
  #   #     w_init=dist_w_init,
  #   #     b_init=dist_b_init,)(x)
  #   return a

  critic_output_dim = 1
  if ensemble_method in [
      'mimo', 'tree_deep_ensembles', 'efficient_tree_deep_ensembles'
  ]:
    critic_output_dim = ensemble_size

  def small_critic(x):
    # i.e. what people typically use for the d4rl benchmark
    _mlp = hk.nets.MLP(
        list(critic_hidden_layer_sizes),
        # w_init=hk.initializers.VarianceScaling(1.0, "fan_avg", "truncated_normal"),
        # w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
        # w_init=hk.initializers.VarianceScaling(critic_init_scale, "fan_avg", "truncated_normal"),
        w_init=w_init,
        b_init=b_init,
        activation=jax.nn.relu,
        # activation=jax.nn.tanh,
        activate_final=True)
    h = _mlp(x)
    _linear = hk.Linear(critic_output_dim, w_init=w_init, b_init=b_init)
    v = _linear(h)
    return v, h

  # def small_critic(x):
  #   # this one is for exploring maximal parameterization
  #   width = 256
  #   x = hk.Linear(
  #       width,
  #       w_init=hk.initializers.VarianceScaling(scale=1.0, mode='fan_out', distribution='truncated_normal'),
  #       b_init=hk.initializers.VarianceScaling(scale=0.05, mode='fan_out', distribution='truncated_normal'))(x)
  #   x = x * (float(width) ** 0.5)
  #   x = jax.nn.relu(x)
  #   x = hk.Linear(
  #       width,
  #       w_init=hk.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal'),
  #       b_init=hk.initializers.VarianceScaling(scale=0.05, mode='fan_in', distribution='truncated_normal'),)(x)
  #   x = jax.nn.relu(x)
  #   x = hk.Linear(
  #       width,
  #       w_init=hk.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal'),
  #       b_init=hk.initializers.VarianceScaling(scale=0.05, mode='fan_in', distribution='truncated_normal'),)(x)
  #   x = jax.nn.relu(x)
  #   h = x
  #   x = hk.Linear(
  #       1,
  #       w_init=hk.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal'),
  #       b_init=hk.initializers.VarianceScaling(scale=0.05, mode='fan_in', distribution='truncated_normal'),)(x)
  #   x = x / (float(width) ** 0.5)
  #   return x, h

  # def large_critic(x):
  #   # inspired by the ones used in RL Unplugged, but smaller hidden layer sizes
  #   hid_dim = 256
  #   _encoder = hk.Linear(hid_dim, w_init=w_init, b_init=b_init)
  #   x = _encoder(x)
  #   for i in range(4):
  #     x = network_utils.ResidualLayerNormBlock(
  #         [hid_dim, hid_dim],
  #         activation=jax.nn.relu,
  #         w_init=w_init,
  #         b_init=b_init,)(x)
  #   h = hk.Linear(hid_dim, w_init=w_init, b_init=b_init)(x)
  #   v = hk.Linear(critic_output_dim, w_init=w_init, b_init=b_init)(h)
  #   return v, h

  def large_critic(x):
    # inspired by the ones used in RL Unplugged
    x = hk.Sequential([
        hk.Linear(400, w_init=rlu_uniform_initializer),
        hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
        jax.lax.tanh,
    ])(x)
    x = hk.Linear(1024, w_init=rlu_uniform_initializer)(x)
    for i in range(4):
      x = network_utils.ResidualLayerNormBlock(
          [1024, 1024],
          activation=jax.nn.relu,
          w_init=rlu_uniform_initializer,
      )(x)
    h = x
    # v = hk.Linear(1, w_init=rlu_uniform_initializer)(h)
    # v = hk.Linear(critic_output_dim)(h)
    all_vs = []
    for _ in range(critic_output_dim):
      head_v = hk.Linear(256, w_init=rlu_uniform_initializer)(h)
      head_v = jax.nn.relu(head_v)
      head_v = hk.Linear(1, w_init=rlu_uniform_initializer)(head_v)
      all_vs.append(head_v)
    v = jnp.concatenate(all_vs, axis=-1)
    return v, h

  # def _critic_fn(obs, action):
  def _all_critic_stuff(obs, action):
    # for matching Ilya's codebase
    # relu_orthogonal = hk.initializers.Orthogonal(scale=2.0**0.5)
    # near_zero_orthogonal = hk.initializers.Orthogonal(1e-2)
    # def _cn(x):
    #   for hid_dim in critic_hidden_layer_sizes:
    #     x = hk.Linear(hid_dim, w_init=relu_orthogonal, b_init=jnp.zeros)(x)
    #     x = jax.nn.relu(x)
    #   x = hk.Linear(1, w_init=near_zero_orthogonal, b_init=jnp.zeros)(x)
    #   return x
    # input_ = jnp.concatenate([obs, action], axis=-1)
    # if use_double_q:
    #   value1 = _cn(input_)
    #   value2 = _cn(input_)
    #   return jnp.concatenate([value1, value2], axis=-1)
    # else:
    #   return _cn(input_)

    # w_init = hk.initializers.VarianceScaling(2.0, 'fan_in', 'uniform')
    # b_init = jnp.zeros

    #####################################
    input_ = jnp.concatenate([obs, action], axis=-1)

    if ensemble_method == 'tree_deep_ensembles':
      critic_network_builder = network_utils.build_tree_deep_ensemble_critic(
          w_init, b_init, use_double_q)
    elif ensemble_method == 'efficient_tree_deep_ensembles':
      critic_network_builder = network_utils.build_efficient_tree_deep_ensemble_critic(
          w_init, b_init, use_double_q)
    else:
      # for standard d4rl architecture
      critic_network_builder = small_critic
      # for larger architecture inspired by rl unplugged
      # critic_network_builder = large_critic

    value1, h1 = critic_network_builder(input_)
    if ensemble_method in [
        'mimo', 'tree_deep_ensembles', 'efficient_tree_deep_ensembles'
    ]:
      value1 = jnp.reshape(value1, [-1, ensemble_size, 1])

    if use_double_q:
      value2, h2 = critic_network_builder(input_)
      if ensemble_method in [
          'mimo', 'tree_deep_ensembles', 'efficient_tree_deep_ensembles'
      ]:
        value2 = jnp.reshape(value2, [-1, ensemble_size, 1])
      return (jnp.concatenate([value1, value2], axis=-1),
              jnp.concatenate([h1, h2], axis=-1))
    else:
      return value1, h1

  def get_particular_critic_init(w_init, b_init, key, obs, act):

    def _critic_with_particular_init(obs, action):
      raise NotImplementedError(
          'Not implemented for MIMO, not implemented for the new version that '
          'also returns h1, h2')
      network1 = hk.Sequential([
          hk.nets.MLP(
              list(critic_hidden_layer_sizes) + [1],
              w_init=w_init,
              b_init=b_init,
              activation=jax.nn.relu,
              activate_final=False),
      ])
      input_ = jnp.concatenate([obs, action], axis=-1)
      value1 = network1(input_)
      if use_double_q:
        network2 = hk.Sequential([
            hk.nets.MLP(
                list(critic_hidden_layer_sizes) + [1],
                w_init=w_init,
                b_init=b_init,
                activation=jax.nn.relu,
                activate_final=False),
        ])
        value2 = network2(input_)
        return jnp.concatenate([value1, value2], axis=-1)
      else:
        return value1

    init_fn = hk.without_apply_rng(
        hk.transform(_critic_with_particular_init, apply_rng=True)).init
    return init_fn(key, obs, act)

  kernel_fn = None
  if build_kernel_fn:
    layers = []
    for hid_dim in critic_hidden_layer_sizes:
      # W_std = 1.5
      W_std = 2.0
      layers += [stax.Dense(hid_dim, W_std=W_std, b_std=0.05), stax.Relu()]
    layers += [stax.Dense(1, W_std=W_std, b_std=0.05)]
    nt_init_fn, nt_apply_fn, nt_kernel_fn = stax.serial(*layers)
    kernel_fn = jax.jit(nt_kernel_fn, static_argnums=(2,))

  if img_encoder_fn is not None:
    # _actor_fn = bimanual_sweep.policy_on_encoder_v0(num_dimensions)
    # _critic_fn = bimanual_sweep.critic_on_encoder_v0(use_double_q=use_double_q)
    _actor_fn = bimanual_sweep.policy_on_encoder_v1(num_dimensions)
    raise NotImplementedError(
        'Need to handle the returning of h1, h2 with new version of '
        'all_critic_stuff')
    _critic_fn = bimanual_sweep.critic_on_encoder_v1(use_double_q=use_double_q)

  def _simclr_encoder(h):
    # return hk.nets.MLP(
    #     [256, 128],
    #     # [256, 256, 256],
    #     w_init=w_init,
    #     # b_init=b_init,  # b_init should not be set when not using bias
    #     with_bias=False,
    #     activation=jax.nn.relu,
    #     activate_final=False)(h)
    # IF YOU CHANGE THIS AND USE SASS, YOU NEED TO FIX THE SASS ENCODER OPTIM STEP
    return h  # i.e. no encoder (sometimes referred to as "projection")

  policy = hk.without_apply_rng(hk.transform(_actor_fn, apply_rng=True))
  # critic = hk.without_apply_rng(hk.transform(_critic_fn, apply_rng=True))
  all_critic_stuff = hk.without_apply_rng(
      hk.transform(_all_critic_stuff, apply_rng=True))
  critic_init = all_critic_stuff.init
  critic_apply = lambda p, obs, act: all_critic_stuff.apply(p, obs, act)[0]
  critic_repr = lambda p, obs, act: all_critic_stuff.apply(p, obs, act)[1]
  simclr_encoder = hk.without_apply_rng(
      hk.transform(_simclr_encoder, apply_rng=True))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  tile_shape = [1 for _ in range(dummy_action.ndim)]
  tile_shape[0] = 256
  dummy_action = jnp.tile(dummy_action, tile_shape)
  tile_shape = [1 for _ in range(dummy_obs.ndim)]
  tile_shape[0] = 256
  dummy_obs = jnp.tile(dummy_obs, tile_shape)

  if img_encoder_fn is not None:
    img_encoder = hk.without_apply_rng(
        hk.transform(img_encoder_fn, apply_rng=True))
    key = jax.random.PRNGKey(seed=42)
    temp_encoder_params = img_encoder.init(key, dummy_obs['state_image'])
    dummy_hidden = img_encoder.apply(temp_encoder_params,
                                     dummy_obs['state_image'])
    img_encoder_network = networks_lib.FeedForwardNetwork(
        lambda key: img_encoder.init(key, dummy_hidden), img_encoder.apply)
    dummy_encoded_input = dict(
        state_image=dummy_hidden,
        state_dense=dummy_obs['state_dense'],
    )
  else:
    img_encoder_fn = None
    dummy_encoded_input = dummy_obs
    img_encoder_network = None

  critic_dummy_encoded_input = dummy_encoded_input
  critic_dummy_action = dummy_action
  if ensemble_method == 'mimo':
    if mimo_using_obs_tile:
      # if using the version where we are also tiling the obs
      tile_array = [1] * len(critic_dummy_encoded_input.shape)  # type: ignore
      tile_array[-1] = ensemble_size
      critic_dummy_encoded_input = jnp.tile(critic_dummy_encoded_input,
                                            tile_array)
    if mimo_using_act_tile:
      # if using the version where we are also tiling the acts
      tile_array = [1] * len(critic_dummy_action.shape)
      tile_array[-1] = ensemble_size
      critic_dummy_action = jnp.tile(critic_dummy_action, tile_array)

  temp_critic_params = critic_init(jax.random.PRNGKey(42),
                                   critic_dummy_encoded_input,
                                   critic_dummy_action)
  dummy_critic_repr = critic_repr(temp_critic_params,
                                  critic_dummy_encoded_input,
                                  critic_dummy_action)

  # mixture_sample = build_gaussian_mixture_sample(num_dimensions, NUM_MIXTURE_COMPONENTS, eval_mode=False)
  # mixture_sample_eval = build_gaussian_mixture_sample(num_dimensions, NUM_MIXTURE_COMPONENTS, eval_mode=True)
  # mixture_log_prob = build_gaussian_mixture_log_prob(num_dimensions, NUM_MIXTURE_COMPONENTS)

  return MSGNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_encoded_input), policy.apply),
      q_network=networks_lib.FeedForwardNetwork(
          lambda key: critic_init(key, critic_dummy_encoded_input,
                                  critic_dummy_action), critic_apply),
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      # sample_eval=lambda params, key: params.mode(),
      sample_eval=lambda params, key: params.sample(seed=key),
      # log_prob=mixture_log_prob,
      # sample=mixture_sample,
      # # sample_eval=lambda params, key: params.mode(),
      # sample_eval=mixture_sample_eval,
      img_encoder=img_encoder_network,
      kernel_fn=kernel_fn,
      get_particular_critic_init=lambda w_init, b_init, key:
      get_particular_critic_init(w_init, b_init, key, dummy_encoded_input,
                                 dummy_action),
      get_critic_repr=critic_repr,
      simclr_encoder=networks_lib.FeedForwardNetwork(
          lambda key: simclr_encoder.init(key, dummy_critic_repr),
          simclr_encoder.apply),
  )
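
# When `build_kernel_fn=True`, the `kernel_fn` field returned above is the
# jitted infinite-width NNGP/NTK kernel of the critic MLP trunk. A hedged usage
# sketch follows; `sa_batch1` and `sa_batch2` are hypothetical arrays of
# concatenated (observation, action) features, not names defined above.
networks = make_networks(spec, build_kernel_fn=True)
k_ntk = networks.kernel_fn(sa_batch1, sa_batch2, 'ntk')    # NTK between batches
k_nngp = networks.kernel_fn(sa_batch1, sa_batch2, 'nngp')  # NNGP kernel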
if model_name == 'Myrtle':
  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std,
                padding='SAME'),
      stax.Relu(),
      stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std,
                padding='SAME'),
      stax.Relu(),
      stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
      stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std,
                padding='SAME'),
      stax.Relu(),
      stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
      stax.Conv(512, (3, 3), strides=(1, 1), W_std=W_std, b_std=b_std,
                padding='SAME'),
      stax.Relu(),
      stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
      stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
      stax.AvgPool((2, 2), strides=(2, 2), padding='VALID'),
      stax.Flatten(),
      stax.Dense(10, W_std, b_std))
else:
  raise ValueError('Invalid model name: {}'.format(model_name))

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnums=(2,))
kernel_fn = nt.batch(kernel_fn, batch_size=20)

X1 = X[row_id * m:(row_id + 1) * m, :, :, :]
assert (X1.shape[0] == m and X1.shape[1] == 32 and X1.shape[2] == 32 and
        X1.shape[3] == 3)

# Training kernel.
K = onp.zeros((m, n), dtype=onp.float32)
col_count = n // m
for col_id in range(row_id, col_count):
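
# The snippet above is cut off inside the column loop, so the original loop
# body is not shown. Purely as a hedged sketch (assuming `X` holds all n
# inputs and `K` is the m x n block-row of the NTK for `row_id`), one way such
# a block-row computation could proceed:
for col_id in range(row_id, col_count):
  X2 = X[col_id * m:(col_id + 1) * m, :, :, :]
  K[:, col_id * m:(col_id + 1) * m] = kernel_fn(X1, X2, 'ntk')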
def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in,
                   fan_in_mode):
  if fan_in_mode in ['FanInSum', 'FanInProd']:
    if axis != 0:
      raise absltest.SkipTest('`FanInSum` and `FanInProd` are skipped when '
                              'axis != 0.')
    axis = None

  if (fan_in_mode == 'FanInSum' or
      axis == 0) and branch_in == 'dense_after_branch_in':
    raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` '
                            'require `is_gaussian`.')

  if ((axis == 1 or fan_in_mode == 'FanInProd') and
      branch_in == 'dense_before_branch_in'):
    raise absltest.SkipTest(
        '`FanInConcat` or `FanInProd` on feature axis requires a dense layer '
        'after concatenation or Hadamard product.')

  if fan_in_mode == 'FanInSum':
    fan_in_layer = stax.FanInSum()
  elif fan_in_mode == 'FanInProd':
    fan_in_layer = stax.FanInProd()
  else:
    fan_in_layer = stax.FanInConcat(axis)

  if n_branches != 2:
    test_utils.skip_test(self)

  key = random.PRNGKey(1)
  X0_1 = np.cos(random.normal(key, (4, 3)))
  X0_2 = None if same_inputs else random.normal(key, (8, 3))

  width = 1024
  n_samples = 256 * 2

  if default_backend() == 'tpu':
    tol = 0.07
  else:
    tol = 0.02

  dense = stax.Dense(width, 1.25, 0.1)
  input_layers = [dense, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      multiplier = 1 if axis not in (1, -1) else (1 + 0.25 * i)
      branch_layers += [
          stax.Dense(int(width * multiplier), 1. + 2 * i, 0.5 + i),
          FanInTest._get_phi(i)
      ]

    if branch_in == 'dense_before_branch_in':
      branch_layers += [dense]
    branches += [stax.serial(*branch_layers)]

  output_layers = [fan_in_layer, stax.Relu()]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, dense)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = nn
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
  else:
    raise ValueError(get)

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=0 if axis in (0, -2) else -1,
      implementation=2,
      vmap_axes=None if axis in (0, -2) else 0,
  )

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout, fan_in_mode):
  test_utils.skip_test(self)

  if fan_in_mode in ['FanInSum', 'FanInProd']:
    if axis != 0:
      raise absltest.SkipTest('`FanInSum` and `FanInProd()` are skipped when '
                              'axis != 0.')
    axis = None

  if (fan_in_mode == 'FanInSum' or
      axis in [0, 1, 2]) and branch_in == 'dense_after_branch_in':
    raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                            'require `is_gaussian`.')

  if ((axis == 3 or fan_in_mode == 'FanInProd') and
      branch_in == 'dense_before_branch_in'):
    raise absltest.SkipTest('`FanInConcat` or `FanInProd` on feature axis '
                            'requires a dense layer after concatenation '
                            'or Hadamard product.')

  if fan_in_mode == 'FanInSum':
    fan_in_layer = stax.FanInSum()
  elif fan_in_mode == 'FanInProd':
    fan_in_layer = stax.FanInProd()
  else:
    fan_in_layer = stax.FanInConcat(axis)

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (2, 5, 6, 3))
  X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

  if default_backend() == 'tpu':
    width = 2048
    n_samples = 1024
    tol = 0.02
  else:
    width = 1024
    n_samples = 512
    tol = 0.01

  conv = stax.Conv(out_chan=width,
                   filter_shape=(3, 3),
                   padding='SAME',
                   W_std=1.25,
                   b_std=0.1)

  input_layers = [conv, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      multiplier = 1 if axis not in (3, -1) else (1 + 0.25 * i)
      branch_layers += [
          stax.Conv(out_chan=int(width * multiplier),
                    filter_shape=(i + 1, 4 - i),
                    padding='SAME',
                    W_std=1.25 + i,
                    b_std=0.1 + i),
          FanInTest._get_phi(i)
      ]

    if branch_in == 'dense_before_branch_in':
      branch_layers += [conv]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      fan_in_layer,
      stax.Relu(),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
  ]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, conv)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=0 if axis in (0, -4) else -1,
      implementation=2,
      vmap_axes=None if axis in (0, -4) else 0,
  )

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_kwargs(self, do_batch, mode):
  rng = random.PRNGKey(1)

  x_train = random.normal(rng, (8, 7, 10))
  x_test = random.normal(rng, (4, 7, 10))
  y_train = random.normal(rng, (8, 1))

  rng_train, rng_test = random.split(rng, 2)

  pattern_train = random.normal(rng, (8, 7, 7))
  pattern_test = random.normal(rng, (4, 7, 7))

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(8),
      stax.Relu(),
      stax.Dropout(rate=0.4),
      stax.Aggregate(),
      stax.GlobalAvgPool(),
      stax.Dense(1))

  kw_dd = dict(pattern=(pattern_train, pattern_train))
  kw_td = dict(pattern=(pattern_test, pattern_train))
  kw_tt = dict(pattern=(pattern_test, pattern_test))

  if mode == 'mc':
    kernel_fn = monte_carlo_kernel_fn(init_fn, apply_fn, rng, 2,
                                      batch_size=2 if do_batch else 0)

  elif mode == 'empirical':
    kernel_fn = empirical_kernel_fn(apply_fn)
    if do_batch:
      raise absltest.SkipTest('Batching of empirical kernel is not '
                              'implemented with keyword arguments.')
    for kw in (kw_dd, kw_td, kw_tt):
      kw.update(dict(params=init_fn(rng, x_train.shape)[1],
                     get=('nngp', 'ntk')))

    kw_dd.update(dict(rng=(rng_train, None)))
    kw_td.update(dict(rng=(rng_test, rng_train)))
    kw_tt.update(dict(rng=(rng_test, None)))

  elif mode == 'analytic':
    if do_batch:
      kernel_fn = batch.batch(kernel_fn, batch_size=2)

  else:
    raise ValueError(mode)

  k_dd = kernel_fn(x_train, None, **kw_dd)
  k_td = kernel_fn(x_test, x_train, **kw_td)
  k_tt = kernel_fn(x_test, None, **kw_tt)

  # Infinite time NNGP/NTK.
  predict_fn_gp = predict.gp_inference(k_dd, y_train)
  out_gp = predict_fn_gp(k_test_train=k_td, nngp_test_test=k_tt.nngp)

  if mode == 'empirical':
    for kw in (kw_dd, kw_td, kw_tt):
      kw.pop('get')

  predict_fn_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                              x_train,
                                                              y_train,
                                                              **kw_dd)
  out_ensemble = predict_fn_ensemble(x_test=x_test, compute_cov=True, **kw_tt)
  self.assertAllClose(out_gp, out_ensemble)

  # Finite time NTK test.
  predict_fn_mse = predict.gradient_descent_mse(k_dd.ntk, y_train)
  out_mse = predict_fn_mse(t=1.,
                           fx_train_0=None,
                           fx_test_0=0.,
                           k_test_train=k_td.ntk)
  out_ensemble = predict_fn_ensemble(t=1.,
                                     get='ntk',
                                     x_test=x_test,
                                     compute_cov=False,
                                     **kw_tt)
  self.assertAllClose(out_mse, out_ensemble)

  # Finite time NNGP train.
  predict_fn_mse = predict.gradient_descent_mse(k_dd.nngp, y_train)
  out_mse = predict_fn_mse(t=2.,
                           fx_train_0=0.,
                           fx_test_0=None,
                           k_test_train=k_td.nngp)
  out_ensemble = predict_fn_ensemble(t=2.,
                                     get='nngp',
                                     x_test=None,
                                     compute_cov=False,
                                     **kw_dd)
  self.assertAllClose(out_mse, out_ensemble)