Beispiel #1
0
# Make implicit NumPy-style rank promotion an error instead of silently
# broadcasting arrays of different rank.
config.update('jax_numpy_rank_promotion', 'raise')

# Candidate batch sizes.
# NOTE(review): the consumers of these lists are not visible in this excerpt;
# presumably they are test parameters -- confirm against the test cases.
BATCH_SIZES = [
    2,
    4,
]

# Hidden-layer width.
# NOTE(review): `WIDTH` is assigned again (to 1024) further down this listing;
# if both assignments live in the same module, the later one wins.
WIDTH = 256

# Device counts to exercise.
# NOTE(review): the meaning of 0 is not shown here -- confirm with the caller.
DEVICE_COUNTS = [0, 1, 2]

# Both settings of a boolean "store on device" option.
STORE_ON_DEVICE = [True, False]

# Kernel selectors: each kernel alone, both as a tuple, and None.
# NOTE(review): presumably passed as the `get` argument of a kernel function
# -- confirm against the call sites.
ALL_GET = ('nngp', 'ntk', ('nngp', 'ntk'), None)

# Apply the test utilities' tolerance settings with their default arguments.
test_utils.update_test_tolerance()


def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):
    """Build two fixed random input batches and a small `stax` network.

    The inputs are drawn from a PRNG key seeded with the constant 1. `x1` is
    sampled with shape (8, 4, 3, 2) and `x2` with shape (4, 4, 3, 2).

    Args:
        width: First argument passed to the first layer (`stax.Conv` or
            `stax.Dense`).
        n_classes: First argument passed to the final `stax.Dense` layer.
        use_conv: If true, the first layer is `stax.Conv(width, (3, 3))` and
            the inputs keep their 4-D shape. Otherwise the first layer is
            `stax.Dense(width)` and each input is reshaped to 2-D, keeping
            the leading axis and flattening the rest.

    NOTE(review): no `return` statement is visible in this excerpt, so as
    written the function returns None and discards `x1`, `x2`, `init_fn`,
    `apply_fn` and `kernel_fn`. The listing appears to be cut off here --
    confirm against the full source.
    """
    key = random.PRNGKey(1)
    # Use `key` for x1 and `split` for x2 so the two batches are drawn from
    # different keys.
    key, split = random.split(key)
    x1 = random.normal(key, (8, 4, 3, 2))
    x2 = random.normal(split, (4, 4, 3, 2))

    if not use_conv:
        # Flatten everything except the leading axis for the dense network.
        x1 = np.reshape(x1, (x1.shape[0], -1))
        x2 = np.reshape(x2, (x2.shape[0], -1))

    # Positional `2., 0.5` on the last Dense layer are presumably W_std and
    # b_std -- TODO confirm against the stax.Dense signature.
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(width, (3, 3)) if use_conv else stax.Dense(width),
        stax.Relu(), stax.Flatten(), stax.Dense(n_classes, 2., 0.5))
# Have JAX pick up its configuration flags via absl's flag parsing.
jax_config.parse_flags_with_absl()

# Network-architecture tags. `_build_network` below compares its `network`
# argument against the string values 'FLAT' and 'POOLING'.
STANDARD = 'FLAT'
POOLING = 'POOLING'
INTERMEDIATE_CONV = 'INTERMEDIATE_CONV'

# TODO(schsam): Add a pooling test when multiple inputs are supported in
# Conv + Pooling.
# TRAIN_SHAPES, TEST_SHAPES and NETWORK each have five entries.
# NOTE(review): presumably they are matched index-by-index (train shape, test
# shape, architecture tag) -- confirm against the test parameterization.
TRAIN_SHAPES = [(2, 4), (4, 8), (8, 8), (8, 4, 4, 3), (4, 3, 3, 3)]
TEST_SHAPES = [(2, 4), (2, 8), (16, 8), (2, 4, 4, 3), (2, 3, 3, 3)]
NETWORK = [STANDARD, STANDARD, STANDARD, STANDARD, INTERMEDIATE_CONV]
# Candidate numbers of output logits.
OUTPUT_LOGITS = [1, 2, 3]
# First argument given to `stax.Conv` in `_build_network`.
CONVOLUTION_CHANNELS = 4
# First argument given to the hidden `stax.Dense` layer in `_build_network`.
# NOTE(review): `WIDTH` is also assigned 256 earlier in this listing; if both
# assignments share a module, this later value overrides the earlier one.
WIDTH = 1024
# Relative tolerance constant.
# NOTE(review): its consumers are not visible in this excerpt.
RTOL = 1e-2
# Override the test utilities' float64 tolerance with 5e-5.
test_utils.update_test_tolerance(f64_tol=5e-5)


def _build_network(input_shape, network, out_logits, use_dropout):
    dropout = stax.Dropout(0.9,
                           mode='train') if use_dropout else stax.Identity()
    if len(input_shape) == 1:
        assert network == 'FLAT'
        return stax.serial(stax.Dense(WIDTH, W_std=2.0, b_std=0.5), dropout,
                           stax.Dense(out_logits, W_std=2.0, b_std=0.5))
    elif len(input_shape) == 3:
        if network == 'POOLING':
            return stax.serial(
                stax.Conv(CONVOLUTION_CHANNELS, (2, 2), W_std=2.0, b_std=0.05),
                stax.GlobalAvgPool(), dropout,
                stax.Dense(out_logits, W_std=2.0, b_std=0.5))