def test_nested_parallel(self, same_inputs, kernel_type):
  platform = default_backend()
  rtol = RTOL if platform != 'tpu' else 0.05

  rng = random.PRNGKey(0)
  (input_key1, input_key2, input_key3, input_key4,
   mask_key, mc_key) = random.split(rng, 6)

  x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
  x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
  x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
  x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

  m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

  x1_1 = test_utils.mask(
      x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
  x1_2 = test_utils.mask(
      x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)
  if not same_inputs:
    x2_3 = test_utils.mask(
        x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
    x2_4 = test_utils.mask(
        x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

  x1 = (((x1_1, x1_2), x1_3), x1_4)
  x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

  N_in = 2 ** 7

  # We only include dropout on non-TPU backends, because it takes large N to
  # converge on TPU.
  dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

  init_fn, apply_fn, kernel_fn = stax.parallel(
      stax.parallel(
          stax.parallel(stax.Dense(N_in),
                        stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                    stax.Flatten())),
          stax.serial(stax.Conv(N_in + 2, (2, 2)),
                      dropout_or_id,
                      stax.GlobalAvgPool())),
      stax.Conv(N_in + 3, (2,)))

  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                 if platform == 'tpu' else None)
  )

  test_utils.assert_close_matrices(
      self,
      kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
      kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
      rtol)
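
# A minimal sketch of the pattern the test above exercises: the analytic
# infinite-width kernel of a `stax` network is compared against a Monte Carlo
# estimate averaged over random finite-width instantiations. The function name
# below is hypothetical and not part of the test suite; it assumes only the
# public `neural_tangents` API.
def _sketch_mc_vs_analytic_kernel():
  from jax import random
  import neural_tangents as nt
  from neural_tangents import stax

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(512), stax.Relu(), stax.Dense(1))
  key = random.PRNGKey(0)
  x = random.normal(key, (4, 8))

  exact = kernel_fn(x, None, 'ntk')  # closed-form infinite-width NTK
  mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key,
                                          n_samples=128)
  estimate = mc_kernel_fn(x, None, 'ntk')  # average over finite-width draws
  return exact, estimate
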
def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant, concat,
                   proj, p, n, transpose):
  if isinstance(concat, int) and concat > n:
    raise absltest.SkipTest('Concatenation axis out of bounds.')

  test_utils.skip_test(self)
  if default_backend() == 'gpu' and n > 3:
    raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

  width = 256
  n_samples = 256
  tol = 0.03
  key = random.PRNGKey(1)

  spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
  filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
  strides = (2, 1, 3, 2, 3)[:n]
  spatial_spec = 'HWDZX'[:n]
  dimension_numbers = ('N' + spatial_spec + 'C',
                       'OI' + spatial_spec,
                       'N' + spatial_spec + 'C')

  x1 = np.cos(random.normal(key, (2,) + spatial_shape + (2,)))
  x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

  if same_inputs:
    x2 = None
  else:
    x2 = np.cos(random.normal(key, (4,) + spatial_shape + (2,)))
    x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

  def get_attn():
    return stax.GlobalSelfAttention(
        n_chan_out=width,
        n_chan_key=width,
        n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
        n_heads=int(np.sqrt(width)),
    ) if proj == 'avg' else stax.Identity()

  conv = stax.ConvTranspose if transpose else stax.Conv

  nn = stax.serial(
      stax.FanOut(3),
      stax.parallel(
          stax.serial(
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='CIRCULAR',
                   W_std=1.5,
                   b_std=0.2),
              stax.LayerNorm(axis=(1, -1)),
              stax.Abs(),
              stax.DotGeneral(rhs=0.9),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=1.2,
                   b_std=0.1),
          ),
          stax.serial(
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='SAME',
                   W_std=0.1,
                   b_std=0.3),
              stax.Relu(),
              stax.Dropout(0.7),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=0.9,
                   b_std=1.),
          ),
          stax.serial(
              get_attn(),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='CIRCULAR',
                   W_std=1.,
                   b_std=0.1),
              stax.Erf(),
              stax.Dropout(0.2),
              stax.DotGeneral(rhs=0.7),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=1.,
                   b_std=0.1),
          )),
      (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
      get_attn(),
      {
          'avg': stax.GlobalAvgPool(),
          'sum': stax.GlobalSumPool(),
          'flatten': stax.Flatten(),
      }[proj],
  )

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(width, 1., 0.))
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.))
  else:
    raise ValueError(get)

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if concat in (0, -n) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat in (0, -n) else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
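
# A minimal sketch, assuming the `mask_constant` semantics exercised above:
# input entries equal to `mask_constant` are treated as padding by the kernel
# computation. The function name is hypothetical and only illustrates how a
# masked input is constructed and passed to a convolutional `kernel_fn`.
def _sketch_masked_conv_kernel():
  from jax import random
  from neural_tangents import stax

  _, _, kernel_fn = stax.serial(
      stax.Conv(256, (3,), padding='SAME'),
      stax.Relu(),
      stax.GlobalAvgPool())
  key = random.PRNGKey(0)
  x = random.normal(key, (2, 15, 2))  # (batch, spatial, channels)
  x = x.at[:, 10:, :].set(-1.)        # mark trailing positions as padding
  # Entries equal to -1. are excluded from the kernel via `mask_constant`.
  return kernel_fn(x, None, 'nngp', mask_constant=-1.)
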
def test_mask_fc(self, same_inputs, get, concat, p, mask_axis, mask_constant):
  width = 512
  n_samples = 128
  tol = 0.04
  key = random.PRNGKey(1)

  x1 = random.normal(key, (4, 6, 5, 7))
  x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

  if same_inputs:
    x2 = None
  else:
    x2 = random.normal(key, (2, 6, 5, 7))
    x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

  nn = stax.serial(
      stax.Flatten(),
      stax.FanOut(3),
      stax.parallel(
          stax.serial(
              stax.Dense(width, 1., 0.1),
              stax.Abs(),
              stax.DotGeneral(lhs=-0.2),
              stax.Dense(width, 1.5, 0.01),
          ),
          stax.serial(
              stax.Dense(width, 1.1, 0.1),
              stax.DotGeneral(rhs=0.7),
              stax.Erf(),
              stax.Dense(width if concat != 1 else 512, 1.5, 0.1),
          ),
          stax.serial(
              stax.DotGeneral(rhs=0.5),
              stax.Dense(width, 1.2),
              stax.ABRelu(-0.2, 0.4),
              stax.Dense(width if concat != 1 else 1024, 1.3, 0.2),
          )),
      (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
      stax.Dense(width, 2., 0.01),
      stax.Relu())

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(width, 1., 0.1))
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.1))
  else:
    raise ValueError(get)

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, key, n_samples,
      device_count=0 if concat in (0, -2) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat in (0, -2) else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
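
# A minimal sketch of the multi-branch topology tested above: `FanOut`
# replicates the input, `parallel` routes one copy through each branch, and
# `FanInConcat` (or `FanInSum`) merges the branch outputs. The variable name
# is hypothetical; widths and nonlinearities are illustrative only.
_, _, _sketch_branch_kernel_fn = stax.serial(
    stax.Flatten(),
    stax.FanOut(2),
    stax.parallel(
        stax.serial(stax.Dense(256), stax.Relu()),
        stax.serial(stax.Dense(256), stax.Erf())),
    stax.FanInConcat(axis=-1),
    stax.Dense(1))
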