def _f(p):
    """Evaluate `f` at parameters `p` and return the output's `masked_value`.

    `f`, `x` and `apply_fn_kwargs` come from the enclosing scope.

    Args:
        p: parameters passed as the first argument of `f`.

    Returns:
        The result of reading `.masked_value` from the output of
        `utils.get_masked_array`, applied through `utils.nt_tree_fn()`.
    """
    masked = utils.get_masked_array(f(p, x, **apply_fn_kwargs))
    # TODO(romann): normalize properly if output is masked.

    # NOTE(review): `nt_tree_fn` presumably maps the accessor over nested
    # tuples/lists of outputs -- confirm against `utils`.
    read_values = utils.nt_tree_fn()(lambda leaf: leaf.masked_value)
    return read_values(masked)
def _f(p):
    """Evaluate `f` at parameters `p`, read `masked_value`, then squeeze.

    `f`, `x`, `apply_fn_kwargs` and `fx_axis` come from the enclosing scope.

    Args:
        p: parameters passed as the first argument of `f`.

    Returns:
        `_squeeze(values, fx_axis)`, where `values` is the `.masked_value` of
        the output of `utils.get_masked_array`, read through
        `utils.nt_tree_fn()`.
    """
    masked = utils.get_masked_array(f(p, x, **apply_fn_kwargs))
    # TODO(romann): normalize properly if output is masked.

    # NOTE(review): `nt_tree_fn` presumably maps the accessor over nested
    # tuples/lists of outputs -- confirm against `utils`.
    read_values = utils.nt_tree_fn()(lambda leaf: leaf.masked_value)
    values = read_values(masked)
    return _squeeze(values, fx_axis)
def test_parallel_in_out(self, same_inputs):
    """Batched kernel functions must match unbatched ones on nested inputs.

    Builds a nested `stax.parallel` read-in and read-out network, feeds them
    nested tuple inputs, and compares the composed kernel with the composition
    of the `batch.batch(..., 2)`-wrapped kernel functions, first for the NNGP
    only and then for the NTK taken from a `('nngp', 'ntk')` request.

    NOTE(review): `same_inputs` is not read anywhere in this body; `x2` is
    always a separately sampled input. Confirm whether that is intended.
    """
    test_utils.stub_out_pmap(batch, 2)
    key = random.PRNGKey(0)
    # The third key is split off but never consumed.
    key_x1, key_x2, _unused_key = random.split(key, 3)

    # Each draw is unpacked along its leading axis into three arrays.
    x1_a, x1_b, x1_c = random.normal(key_x1, (3, 4, 10))
    x2_a, x2_b, x2_c = random.normal(key_x2, (3, 8, 10))

    # Nested structure mirrors the nested `stax.parallel` below.
    x1 = (x1_a, (x1_b, x1_c))
    x2 = (x2_a, (x2_b, x2_c))

    def net(n_out):
        first_branch = stax.Dense(n_out)
        nested_branch = stax.parallel(
            stax.Dense(n_out + 1), stax.Dense(n_out + 2))
        return stax.parallel(first_branch, nested_branch)

    readin = net(WIDTH)
    readout = net(1)

    def make_kernel_fns(get):
        # Index 2 of a layer tuple is used as the kernel function.
        direct_in = jit(readin[2])
        direct_out = jit(partial(readout[2], get=get))
        batched_in = batch.batch(direct_in, 2)
        batched_out = batch.batch(direct_out, 2)
        return direct_in, direct_out, batched_in, batched_out

    # Check NNGP.
    k_in, k_out, batched_in, batched_out = make_kernel_fns('nngp')
    test_utils.assert_close_matrices(
        self,
        k_out(k_in(x1, x2)),
        batched_out(batched_in(x1, x2)),
        RTOL)

    # Check Both.
    k_in, k_out, batched_in, batched_out = make_kernel_fns(('nngp', 'ntk'))
    get_ntk = utils.nt_tree_fn()(lambda kernel: kernel.ntk)
    test_utils.assert_close_matrices(
        self,
        get_ntk(k_out(k_in(x1, x2))),
        get_ntk(batched_out(batched_in(x1, x2))),
        RTOL)
def parallel_fn_kernel(kernel, *args, **kwargs):
    """Apply the enclosing-scope `kernel_fn` to `kernel` across devices.

    The kernel is reshaped with `_reshape_kernel_for_pmap`, passed through
    `_jit_or_pmap_broadcast(kernel_fn, device_count)`, and flattened again
    with `_flatten_kernel`.

    Args:
        kernel: a kernel (or nested tuple/list of kernels) exposing `cov1` and
            `cov2`.
        *args: forwarded positionally to the wrapped kernel function.
        **kwargs: forwarded by keyword to the wrapped kernel function; also
            checked by `_check_dropout`.

    Returns:
        The output of `_flatten_kernel` on the transformed kernel.
    """

    # For a tree of kernels only the first entry's sizes are kept.
    @utils.nt_tree_fn(reduce=lambda sizes: sizes[0])
    def get_batch_sizes(k):
        rows = k.cov1.shape[0]
        # With no `cov2`, the second size equals the first.
        cols = rows if k.cov2 is None else k.cov2.shape[0]
        return rows, cols

    # True only when every kernel in the tree has `cov2 is None`.
    @utils.nt_tree_fn(reduce=all)
    def has_no_cov2(k):
        return k.cov2 is None

    n1, n2 = get_batch_sizes(kernel)
    _check_dropout(n1, n2, kwargs)
    n1_per_device, n_devices = _get_n_per_device(n1, n2)
    mapped_kernel_fn = _jit_or_pmap_broadcast(kernel_fn, n_devices)

    # Recorded before reshaping so it can be restored on the output below.
    cov2_is_none = has_no_cov2(kernel)

    reshaped = _reshape_kernel_for_pmap(kernel, n_devices, n1_per_device)
    result = mapped_kernel_fn(reshaped, *args, **kwargs)
    if cov2_is_none:
        result = _set_cov2_is_none(result)
    return _flatten_kernel(result, cov2_is_none, True)
def output(x, **kwargs):
    """Evaluate `f(params, x, **kwargs)` and return its `masked_value`.

    `f` and `params` come from the enclosing scope.

    Args:
        x: input passed as the second argument of `f`.
        **kwargs: forwarded to `f`.

    Returns:
        The result of reading `.masked_value` from the output of
        `utils.get_masked_array`, applied through `utils.nt_tree_fn()`.
    """
    masked_output = utils.get_masked_array(f(params, x, **kwargs))
    # NOTE(review): `nt_tree_fn` presumably maps the accessor over nested
    # tuples/lists of outputs -- confirm against `utils`.
    read_values = utils.nt_tree_fn()(lambda leaf: leaf.masked_value)
    return read_values(masked_output)
def serial_fn_kernel(k: NTTree[Kernel], *args, **kwargs) -> NTTree[Kernel]:
    """Apply the enclosing-scope `kernel_fn` to `k` one tile at a time.

    The first two dimensions of `nngp` (taken from the first leaf of `k`) are
    split into batches by `_get_n_batches_and_batch_sizes`. `kernel_fn` is
    then called on each `(row, column)` slice of `k` via two nested `_scan`
    calls, and the stacked results are combined with `flatten`.

    `kernel_fn`, `batch_size`, `device_count`, `flatten` and `_scan` come from
    outside this function.

    Args:
        k: a kernel or nested tuple/list of kernels.
        *args: forwarded positionally to every `kernel_fn` call.
        **kwargs: forwarded to every `kernel_fn` call. Values for which
            `_is_np_ndarray` is true must be 2-tuples; their first and second
            elements are split along the leading axis into row and column
            batches and passed per tile as a pair. All other values are
            forwarded unchanged.

    Returns:
        The output of `flatten` on the scanned kernels.
    """
    # pytype: disable=attribute-error
    def leading_sizes(tree):
        # TODO(schsam): We might want to check for consistency here, but I can't
        # imagine a case where we could get inconsistent kernels.
        while utils.is_list_or_tuple(tree):
            tree = tree[0]
        return tree.nngp.shape[:2]
    # pytype: enable=attribute-error

    n1, n2 = leading_sizes(k)
    batching = _get_n_batches_and_batch_sizes(n1, n2, batch_size, device_count)
    n1_batches, n1_batch_size, n2_batches, n2_batch_size = batching

    # Start offsets of every row tile and every column tile.
    row_starts = np.arange(0, n1, n1_batch_size)
    col_starts = np.arange(0, n2, n2_batch_size)

    @utils.nt_tree_fn(nargs=1)
    def slice_kernel(kernel, rows, cols):
        return kernel.slice(rows, cols)

    def split_leading_axis(array, n_batches, size):
        # Reshape (n_batches * size, ...) into (n_batches, size, ...).
        return np.reshape(array, (n_batches, size) + array.shape[1:])

    row_arrays = {}
    col_arrays = {}
    passthrough = {}
    for name, value in kwargs.items():
        if not _is_np_ndarray(value):
            passthrough[name] = value
            continue
        assert isinstance(value, tuple) and len(value) == 2
        first, second = value
        first_batched = split_leading_axis(first, n1_batches, n1_batch_size)
        second_batched = split_leading_axis(second, n2_batches, n2_batch_size)
        row_arrays[name] = first_batched
        col_arrays[name] = second_batched

    def col_fn(row, col):
        # NOTE(schsam): If we end up wanting to enable jit-of-batch then we will
        # probably have to change this to dynamic slicing.
        row_start, row_kwargs = row
        col_start, col_kwargs = col
        tile_kwargs = {
            **passthrough,
            **{name: (row_kwargs[name], col_kwargs[name])
               for name in row_kwargs},
        }
        rows = slice(row_start, row_start + n1_batch_size)
        cols = slice(col_start, col_start + n2_batch_size)
        tile = slice_kernel(k, rows, cols)
        return (row_start, row_kwargs), kernel_fn(tile, *args, **tile_kwargs)

    def row_fn(carry, row):
        # The carry is passed through untouched; only the stacked outputs of
        # the inner scan over columns are kept.
        _, row_outputs = _scan(col_fn, row, (col_starts, col_arrays))
        return carry, row_outputs

    # Recorded on the input so it can be restored on the output below.
    cov2_is_none = utils.nt_tree_fn(reduce=all)(
        lambda leaf: leaf.cov2 is None)(k)

    _, result = _scan(row_fn, 0, (row_starts, row_arrays))
    if cov2_is_none:
        result = _set_cov2_is_none(result)
    return flatten(result, cov2_is_none)