def test_layernorm(self, model, width, same_inputs, is_ntk, proj_into_2d,
                   layer_norm):
  is_conv = 'conv' in model

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if is_conv:
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNN models on CPU to save time.')
  elif proj_into_2d != PROJECTIONS[0] or layer_norm != LAYER_NORM[0]:
    raise jtu.SkipTest('FC models do not have these parameters.')

  W_std, b_std = 2.**0.5, 0.5**0.5
  filter_size = FILTER_SIZES[0]
  padding = PADDINGS[0]
  strides = STRIDES[0]
  phi = stax.Relu()
  use_pooling, is_res = False, False
  parameterization = 'ntk'
  use_dropout = False

  self._check_agreement_with_empirical(W_std, b_std, filter_size, is_conv,
                                       is_ntk, is_res, layer_norm, padding,
                                       phi, proj_into_2d, same_inputs,
                                       strides, use_pooling, width,
                                       parameterization, use_dropout)
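# Illustrative sketch (not part of the test suite): the test above follows the
# library's standard pattern of comparing a closed-form infinite-width kernel
# against a Monte Carlo estimate over random finite-width networks, here with
# `stax.LayerNorm` in the stack. Widths and sample counts below are
# illustrative assumptions, not the suite's settings.
def _layernorm_agreement_sketch():
  import jax.random as jax_random
  import neural_tangents as nt
  from neural_tangents import stax as nt_stax

  init_fn, apply_fn, kernel_fn = nt_stax.serial(
      nt_stax.Dense(512, W_std=2.**0.5, b_std=0.5**0.5),
      nt_stax.LayerNorm(),  # normalizes over the channel axis by default
      nt_stax.Relu(),
      nt_stax.Dense(1, W_std=2.**0.5, b_std=0.5**0.5))

  key = jax_random.PRNGKey(0)
  x = jax_random.normal(key, (4, 16))

  exact = kernel_fn(x, x, 'nngp')  # analytic infinite-width NNGP kernel
  kernel_fn_mc = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key,
                                          n_samples=128)
  approx = kernel_fn_mc(x, x, 'nngp')  # empirical finite-width estimate
  return exact, approx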
def test_exact(self, model, width, strides, padding, phi, same_inputs,
               filter_size, use_pooling, is_ntk, is_res, proj_into_2d):
  is_conv = 'conv' in model

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if is_conv:
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNN models on CPU to save time.')
    if (is_res and is_conv and
        ((strides is not None and strides != (1, 1)) or
         (padding == 'VALID' and filter_size != (1, 1)))):
      raise jtu.SkipTest('Different paths in a residual model need to return'
                         ' outputs of the same shape.')
  elif (filter_size != FILTER_SIZES[0] or padding != PADDINGS[0] or
        strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
        use_pooling):
    raise jtu.SkipTest('FC models do not have these parameters.')

  if (proj_into_2d.startswith('ATTN') and strides == (2, 1) and
      padding == 'VALID' and xla_bridge.get_backend().platform == 'tpu'):
    # TODO: speed up the vmap alternative impl or fix the current one.
    raise jtu.SkipTest('ATTN forward pass on TPU is broken if one of'
                       ' the spatial dimensions is singleton.')

  W_std, b_std = 2.**0.5, 0.5**0.5
  layer_norm = None

  self._check_agreement_with_empirical(W_std, b_std, filter_size, is_conv,
                                       is_ntk, is_res, layer_norm, padding,
                                       phi, proj_into_2d, same_inputs,
                                       strides, use_pooling, width)
def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides,
                 filter_shape, phi, use_pooling, proj_into_2d):
  if xla_bridge.get_backend().platform == 'tpu' and same_inputs:
    raise jtu.SkipTest('Skip TPU test for `same_inputs`. Need to handle '
                       'random keys carefully for dropout + empirical kernel.')

  pool_type = 'AVG'
  use_dropout = True
  is_conv = 'conv' in model
  is_res = False
  W_std, b_std = 2.**0.5, 0.5**0.5
  layer_norm = None
  parameterization = 'ntk'

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if is_conv:
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNN models on CPU to save time.')
    if (is_res and is_conv and
        ((strides is not None and strides != (1, 1)) or
         (padding == 'VALID' and filter_shape != (1, 1)))):
      raise jtu.SkipTest('Different paths in a residual model need to return'
                         ' outputs of the same shape.')
  elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or
        strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
        use_pooling):
    raise jtu.SkipTest('FC models do not have these parameters.')

  net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                 padding, phi, strides, width, is_ntk, proj_into_2d,
                 pool_type, layer_norm, parameterization, use_dropout)
  self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout,
                                       is_ntk, proj_into_2d)
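# Illustrative sketch (not part of the test suite): `use_dropout` threads a
# `stax.Dropout` layer through `_get_net`. Comparing against the empirical
# kernel requires fresh dropout keys per Monte Carlo sample, which is why the
# TPU `same_inputs` case is skipped above. A minimal network with dropout
# (the keep rate 0.9 and the widths are illustrative assumptions):
def _dropout_sketch():
  from neural_tangents import stax as nt_stax

  init_fn, apply_fn, kernel_fn = nt_stax.serial(
      nt_stax.Dense(512, W_std=2.**0.5, b_std=0.5**0.5),
      nt_stax.Relu(),
      nt_stax.Dropout(0.9),  # argument is the keep rate; 1. keeps all units
      nt_stax.Dense(1))
  return init_fn, apply_fn, kernel_fn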
def test_parameterizations(self, model, width, same_inputs, is_ntk,
                           filter_shape, proj_into_2d, parameterization):
  is_conv = 'conv' in model

  W_std, b_std = 2.**0.5, 0.5**0.5
  padding = PADDINGS[0]
  strides = STRIDES[0]
  phi = stax.Relu()
  use_pooling, is_res = False, False
  layer_norm = None
  pool_type = 'AVG'
  use_dropout = False

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if is_conv:
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNN models on CPU to save time.')
  elif proj_into_2d != PROJECTIONS[0]:
    raise jtu.SkipTest('FC models do not have these parameters.')

  net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                 padding, phi, strides, width, is_ntk, proj_into_2d,
                 pool_type, layer_norm, parameterization, use_dropout)
  self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout,
                                       is_ntk, proj_into_2d)
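# Illustrative sketch (not part of the test suite): the `parameterization`
# argument selects how the weight scales enter the model. Under 'ntk',
# parameters are drawn from N(0, 1) and the W_std / sqrt(fan_in) scaling is
# applied in the forward pass; under 'standard', the scaling is folded into
# the initialization instead. The widths below are illustrative assumptions.
def _parameterization_sketch():
  from neural_tangents import stax as nt_stax

  def make(parameterization):
    return nt_stax.serial(
        nt_stax.Dense(512, W_std=2.**0.5, b_std=0.5**0.5,
                      parameterization=parameterization),
        nt_stax.Relu(),
        nt_stax.Dense(1, parameterization=parameterization))

  return make('ntk'), make('standard')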
def test_exact(self, model, width, strides, padding, phi, same_inputs,
               filter_size, use_pooling, is_ntk, is_res, proj_into_2d):
  is_conv = 'conv' in model

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if is_conv:
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNN models on CPU to save time.')
    if (is_res and is_conv and
        ((strides is not None and strides != (1, 1)) or
         (padding == 'VALID' and filter_size != (1, 1)))):
      raise jtu.SkipTest('Different paths in a residual model need to return'
                         ' outputs of the same shape.')
  elif (filter_size != FILTER_SIZES[0] or padding != PADDINGS[0] or
        strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or
        use_pooling):
    raise jtu.SkipTest('FC models do not have these parameters.')

  if (proj_into_2d.startswith('ATTN') and strides == (2, 1) and
      padding == 'VALID' and xla_bridge.get_backend().platform == 'tpu'):
    # TODO(jirihron): speed up the vmap alternative impl or fix the current
    # one.
    raise jtu.SkipTest('ATTN forward pass on TPU is broken if one of'
                       ' the spatial dimensions is singleton.')

  W_std, b_std = 2.**0.5, 0.5**0.5

  key = random.PRNGKey(1)
  x1, x2 = _get_inputs(key, is_conv, same_inputs, INPUT_SHAPE)
  init_fn, apply_fn, kernel_fn = _get_net(W_std, b_std, filter_size, is_conv,
                                          use_pooling, is_res, padding, phi,
                                          strides, width, is_ntk,
                                          proj_into_2d)

  def _get_empirical(n_samples, get):
    kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples)
    return kernel_fn_empirical(x1, x2, get)

  if proj_into_2d == 'ATTN_PARAM':
    # No analytic kernel available, so just test the forward/backward pass.
    _get_empirical(1, 'ntk' if is_ntk else 'nngp')
  else:
    if is_ntk:
      exact = kernel_fn(x1, x2, 'ntk')
      empirical = np.reshape(_get_empirical(N_SAMPLES, 'ntk'), exact.shape)
    else:
      exact = kernel_fn(x1, x2, 'nngp')
      empirical = _get_empirical(N_SAMPLES, 'nngp')
    utils.assert_close_matrices(self, empirical, exact, RTOL)
def test_exact(self, model, width, strides, padding, phi, same_inputs,
               filter_size, use_pooling, is_ntk, is_res):
  is_conv = 'conv' in model

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if is_conv:
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNN models on CPU to save time.')
    if use_pooling and not same_inputs:
      raise jtu.SkipTest('Pooling layers are not implemented for different '
                         'inputs or for `SAME` padding.')
    if (is_res and is_conv and
        ((strides is not None and strides != (1, 1)) or
         (padding == 'VALID' and filter_size != (1, 1)))):
      raise jtu.SkipTest('Different paths in a residual model need to return'
                         ' outputs of the same shape.')
  elif (filter_size != FILTER_SIZES[0] or padding != PADDINGS[0] or
        strides != STRIDES[0] or use_pooling):
    raise jtu.SkipTest('FC models do not have these parameters.')

  W_std, b_std = 2.**0.5, 0.5**0.5

  key = random.PRNGKey(1)
  x1, x2 = _get_inputs(key, is_conv, same_inputs, INPUT_SHAPE)
  init_fun, apply_fun, ker_fun = _get_net(W_std, b_std, filter_size, is_conv,
                                          use_pooling, is_res, padding, phi,
                                          strides, width, is_ntk)

  if is_ntk:
    exact = ker_fun(x1, x2).ntk
    ker_fun_empirical = monte_carlo.get_ker_fun_monte_carlo(
        init_fun, apply_fun, False, True)
    empirical = ker_fun_empirical(x1, x2, key, N_SAMPLES).ntk
    empirical = np.reshape(empirical, exact.shape)
  else:
    exact = ker_fun(x1, x2, compute_ntk=False).nngp
    ker_fun_empirical = monte_carlo.get_ker_fun_monte_carlo(
        init_fun, apply_fun, True, False)
    empirical = ker_fun_empirical(x1, x2, key, N_SAMPLES).nngp

  utils.assert_close_matrices(self, empirical, exact, RTOL)
def test_pool(self, width, same_inputs, is_ntk, pool_type, padding,
              filter_shape, strides, normalize_edges):
  is_conv = True
  use_dropout = False
  proj_into_2d = 'POOL'

  # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
  if xla_bridge.get_backend().platform == 'cpu':
    raise jtu.SkipTest('Not running CNN models on CPU to save time.')
  if pool_type == 'SUM' and normalize_edges:
    raise jtu.SkipTest('`normalize_edges` not applicable to `SumPool`.')

  net = _get_net_pool(width, is_ntk, pool_type, padding, filter_shape,
                      strides, normalize_edges)
  self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout,
                                       is_ntk, proj_into_2d)
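# Illustrative sketch (not part of the test suite): `normalize_edges` only
# makes sense for average pooling, where it controls whether edge windows
# under `SAME` padding are averaged over the pixels actually covered rather
# than the full window size; hence the `SumPool` skip above. A minimal pooled
# readout in the spirit of `_get_net_pool` (shapes and widths are
# illustrative assumptions):
def _pool_readout_sketch(pool_type='AVG', normalize_edges=True):
  from neural_tangents import stax as nt_stax

  if pool_type == 'AVG':
    pool = nt_stax.AvgPool((2, 2), strides=(1, 1), padding='SAME',
                           normalize_edges=normalize_edges)
  else:
    pool = nt_stax.SumPool((2, 2), strides=(1, 1), padding='SAME')

  return nt_stax.serial(
      nt_stax.Conv(256, (3, 3), padding='SAME', W_std=2.**0.5,
                   b_std=0.5**0.5),
      nt_stax.Relu(),
      pool,
      nt_stax.Flatten(),
      nt_stax.Dense(1))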
def test_batch_sample_once(self, batch_size, device_count, store_on_device,
                           get):
  utils.stub_out_pmap(batch, device_count)

  x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
  kernel_fn = empirical.empirical_kernel_fn(apply_fn)
  sample_once_fn = monte_carlo._sample_once_kernel_fn(kernel_fn, init_fn,
                                                      device_count=0)
  batch_sample_once_fn = batch.batch(sample_once_fn, batch_size, device_count,
                                     store_on_device)

  if get is None:
    raise jtu.SkipTest('No default `get` values for this method.')

  one_sample = sample_once_fn(x1, x2, key, get)
  one_batch_sample = batch_sample_once_fn(x1, x2, key, get)
  self.assertAllClose(one_sample, one_batch_sample, True)
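# Illustrative sketch (not part of the test suite): `batch.batch` wraps any
# kernel function so the covariance matrix is computed in `batch_size`-sized
# blocks, optionally spread over `device_count` devices. The same decorator
# applies to analytic kernels; the sizes below are illustrative assumptions.
def _batching_sketch():
  import jax.random as jax_random
  import neural_tangents as nt
  from neural_tangents import stax as nt_stax

  _, _, kernel_fn = nt_stax.serial(nt_stax.Dense(512), nt_stax.Relu(),
                                   nt_stax.Dense(1))
  batched_kernel_fn = nt.batch(kernel_fn, batch_size=10, device_count=0,
                               store_on_device=False)

  x = jax_random.normal(jax_random.PRNGKey(0), (40, 16))
  return batched_kernel_fn(x, x, 'nngp')  # same result, computed blockwise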
def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                               out_logits):
  if xla_bridge.get_backend().platform == 'gpu' and config.read(
      'jax_enable_x64'):
    raise jtu.SkipTest('Not running GPU x64 to save time.')

  training_steps = 5000
  learning_rate = 1.0
  ensemble_size = 50

  init_fn, apply_fn, ker_fn = stax.serial(
      stax.Dense(1024, W_std=1.2, b_std=0.05), stax.Erf(),
      stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
  opt_update = jit(opt_update)

  key = random.PRNGKey(0)
  key, = random.split(key, 1)

  key, split = random.split(key)
  x_train = np.cos(random.normal(split, train_shape))

  key, split = random.split(key)
  y_train = np.array(
      random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)
  train = (x_train, y_train)

  key, split = random.split(key)
  x_test = np.cos(random.normal(split, test_shape))

  ensemble_key = random.split(key, ensemble_size)

  loss = jit(lambda params, x, y: 0.5 * np.mean(
      (apply_fn(params, x) - y)**2))
  grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

  def train_network(key):
    _, params = init_fn(key, (-1,) + train_shape[1:])
    opt_state = opt_init(params)
    for i in range(training_steps):
      opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)
    return get_params(opt_state)

  params = vmap(train_network)(ensemble_key)

  ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)
  ensemble_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
  ensemble_loss = np.mean(ensemble_loss)
  self.assertLess(ensemble_loss, 1e-5)

  mean_emp = np.mean(ensemble_fx, axis=0)
  mean_subtracted = ensemble_fx - mean_emp
  cov_emp = np.einsum(
      'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
          mean_subtracted.shape[0] * mean_subtracted.shape[-1])

  reg = 1e-7
  ntk_predictions = predict.gp_inference(ker_fn, x_train, y_train, x_test,
                                         'ntk', reg, compute_cov=True)

  self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
  self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL, ATOL)
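# Worked example (not part of the test): the `np.einsum('ijk,ilk->jl', ...)`
# above estimates the covariance over test points by averaging the outer
# product of mean-subtracted ensemble outputs over both ensemble members
# (axis i) and output logits (axis k). An equivalent loop-based computation
# on toy data, using plain NumPy:
def _ensemble_cov_sketch():
  import numpy as onp

  ensemble_size, n_test, n_logits = 5, 3, 2
  fx = onp.random.RandomState(0).randn(ensemble_size, n_test, n_logits)
  centered = fx - fx.mean(axis=0)

  cov_einsum = onp.einsum('ijk,ilk->jl', centered, centered) / (
      ensemble_size * n_logits)

  cov_loops = onp.zeros((n_test, n_test))
  for i in range(ensemble_size):
    for k in range(n_logits):
      cov_loops += onp.outer(centered[i, :, k], centered[i, :, k])
  cov_loops /= ensemble_size * n_logits

  assert onp.allclose(cov_einsum, cov_loops)
  return cov_einsum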
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout):
  if xla_bridge.get_backend().platform == 'cpu':
    raise jtu.SkipTest('Not running CNNs on CPU to save time.')
  if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                       'require `is_gaussian`.')
  if axis == 3 and branch_in == 'dense_before_branch_in':
    raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                       'after concatenation.')

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (2, 5, 6, 3))
  X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

  if xla_bridge.get_backend().platform == 'tpu':
    width = 2048
    n_samples = 1024
    tol = 0.02
  else:
    width = 1024
    n_samples = 512
    tol = 0.01

  conv = stax.Conv(out_chan=width,
                   filter_shape=(3, 3),
                   padding='SAME',
                   W_std=1.25,
                   b_std=0.1)

  input_layers = [conv, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      branch_layers += [
          stax.Conv(out_chan=width,
                    filter_shape=(i + 1, 4 - i),
                    padding='SAME',
                    W_std=1.25 + i,
                    b_std=0.1 + i),
          FanInTest._get_phi(i)]
    if branch_in == 'dense_before_branch_in':
      branch_layers += [conv]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      stax.FanInSum() if axis is None else stax.FanInConcat(axis),
      stax.Relu(),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, conv)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if axis in (0, -4) else -1)

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  empirical = empirical.reshape(exact.shape)
  utils.assert_close_matrices(self, empirical, exact, tol)
def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
  if axis in (None, 0) and branch_in == 'dense_after_branch_in':
    raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                       'require `is_gaussian`.')
  if axis == 1 and branch_in == 'dense_before_branch_in':
    raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                       'after concatenation.')

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (10, 20))
  X0_2 = None if same_inputs else random.normal(key, (8, 20))

  if xla_bridge.get_backend().platform == 'tpu':
    width = 2048
    n_samples = 1024
    tol = 0.02
  else:
    width = 1024
    n_samples = 256
    tol = 0.01

  dense = stax.Dense(width, 1.25, 0.1)
  input_layers = [dense, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      branch_layers += [stax.Dense(width, 1. + 2 * i, 0.5 + i),
                        FanInTest._get_phi(i)]
    if branch_in == 'dense_before_branch_in':
      branch_layers += [dense]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      stax.FanInSum() if axis is None else stax.FanInConcat(axis),
      stax.Relu()]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, dense)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = nn
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
  else:
    raise ValueError(get)

  kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples, device_count=0)

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  empirical = empirical.reshape(exact.shape)
  utils.assert_close_matrices(self, empirical, exact, tol)
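# Illustrative sketch (not part of the test suite): both fan-in tests above
# build the same topology, an input layer fanned out into parallel branches
# that are merged by summation or concatenation. The skeleton, stripped of
# the parameter sweep (widths are illustrative assumptions):
def _fan_in_sketch(n_branches=3):
  from neural_tangents import stax as nt_stax

  # Each branch ends in a dense layer so its output is Gaussian, which
  # `FanInSum` requires (the 'dense_before_branch_in' case above).
  branches = [nt_stax.serial(nt_stax.Relu(), nt_stax.Dense(512, 1.25, 0.1))
              for _ in range(n_branches)]

  return nt_stax.serial(
      nt_stax.Dense(512, 1.25, 0.1),
      nt_stax.FanOut(n_branches),
      nt_stax.parallel(*branches),
      nt_stax.FanInSum(),  # `FanInConcat(axis)` would concatenate instead
      nt_stax.Relu(),
      nt_stax.Dense(1, 1.25, 0.5))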