# Fail loudly on implicit NumPy-style rank promotion so kernel/shape bugs
# surface immediately instead of silently broadcasting.
config.update('jax_numpy_rank_promotion', 'raise')

# Batch sizes swept by the batching tests.
BATCH_SIZES = [
    2,
    4,
]

# Hidden-layer width for finite-width networks in this file.
WIDTH = 256

# Device counts to shard computation over.
# NOTE(review): presumably 0 means "no explicit device parallelism" while
# 1/2 use that many devices — confirm against the test bodies using this.
DEVICE_COUNTS = [0, 1, 2]

# Whether intermediate batched results are kept on device or moved to host.
STORE_ON_DEVICE = [True, False]

# Values for kernel_fn's `get` argument: one quantity, a pair, or the
# default (None).
ALL_GET = ('nngp', 'ntk', ('nngp', 'ntk'), None)

# Relax numeric tolerances to match the active float precision.
test_utils.update_test_tolerance()


def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):
  """Create a fixed-seed pair of inputs and a small serial stax model.

  Args:
    width: number of channels (Conv) or units (Dense) in the hidden layer.
    n_classes: output dimension of the final Dense layer.
    use_conv: if True, the hidden layer is a 3x3 Conv and the inputs stay
      4D; otherwise the inputs are flattened and a Dense layer is used.

  NOTE(review): this definition appears truncated in this chunk — no
  `return` statement is visible. The constructed values (x1, x2, init_fn,
  apply_fn, kernel_fn) are presumably returned just past this view;
  confirm against the full file.
  """
  key = random.PRNGKey(1)  # fixed seed => reproducible test data
  key, split = random.split(key)
  # Two input batches of 8 and 4 examples, each of shape (4, 3, 2).
  # NOTE(review): axis meaning (NHWC vs other) is inferred from Conv usage
  # below — confirm.
  x1 = random.normal(key, (8, 4, 3, 2))
  x2 = random.normal(split, (4, 4, 3, 2))
  if not use_conv:
    # Dense networks take flat inputs: collapse all non-batch axes.
    x1 = np.reshape(x1, (x1.shape[0], -1))
    x2 = np.reshape(x2, (x2.shape[0], -1))
  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Conv(width, (3, 3)) if use_conv else stax.Dense(width),
      stax.Relu(),
      stax.Flatten(),
      stax.Dense(n_classes, 2., 0.5))  # positional W_std=2., b_std=0.5
# Make JAX flags (e.g. --jax_*) parseable under the absl test runner.
jax_config.parse_flags_with_absl()

# Network-topology identifiers used to pick a test architecture.
STANDARD = 'FLAT'
POOLING = 'POOLING'
INTERMEDIATE_CONV = 'INTERMEDIATE_CONV'

# TODO(schsam): Add a pooling test when multiple inputs are supported in
# Conv + Pooling.

# Parallel lists: TRAIN_SHAPES[i] / TEST_SHAPES[i] / NETWORK[i] together
# describe one test configuration (2-tuples -> flat nets, 4-tuples ->
# image-like nets).
TRAIN_SHAPES = [(2, 4), (4, 8), (8, 8), (8, 4, 4, 3), (4, 3, 3, 3)]
TEST_SHAPES = [(2, 4), (2, 8), (16, 8), (2, 4, 4, 3), (2, 3, 3, 3)]
NETWORK = [STANDARD, STANDARD, STANDARD, STANDARD, INTERMEDIATE_CONV]

# Output dimensions to sweep in the tests.
OUTPUT_LOGITS = [1, 2, 3]

CONVOLUTION_CHANNELS = 4  # channel count for conv layers in test nets
WIDTH = 1024              # hidden-layer width for finite-width nets
RTOL = 1e-2               # relative tolerance for numeric comparisons

# Tighten/relax tolerances; f64 comparisons use 5e-5.
test_utils.update_test_tolerance(f64_tol=5e-5)


def _build_network(input_shape, network, out_logits, use_dropout):
  """Build a stax network of the requested topology for `input_shape`.

  Args:
    input_shape: per-example shape — length 1 for flat inputs, length 3
      for image-like inputs.
      NOTE(review): inferred from the len() checks below; the entries of
      TRAIN_SHAPES/TEST_SHAPES include a batch axis, so callers presumably
      strip it before calling — confirm.
    network: one of STANDARD ('FLAT'), POOLING, or INTERMEDIATE_CONV.
    out_logits: output dimension of the final Dense layer.
    use_dropout: if True, insert a train-mode Dropout with rate 0.9;
      otherwise an Identity layer takes its place.

  NOTE(review): this definition is cut off in this chunk — only the FLAT
  branch and the POOLING sub-branch are visible; the INTERMEDIATE_CONV
  branch (and any final else/raise) lies past this view.
  """
  dropout = stax.Dropout(0.9, mode='train') if use_dropout else stax.Identity()
  if len(input_shape) == 1:
    # Flat inputs only make sense with the fully-connected topology.
    assert network == 'FLAT'
    return stax.serial(stax.Dense(WIDTH, W_std=2.0, b_std=0.5), dropout,
                       stax.Dense(out_logits, W_std=2.0, b_std=0.5))
  elif len(input_shape) == 3:
    if network == 'POOLING':
      return stax.serial(
          stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
          stax.GlobalAvgPool(), dropout,
          stax.Dense(out_logits, W_std=2.0, b_std=0.5))