Example #1
0
    def _f(p):
        """Evaluates `f(p, x, **apply_fn_kwargs)` and returns its masked values.

        The raw output is passed through `utils.get_masked_array`, and the
        `.masked_value` attribute is then taken from every leaf of the
        resulting tree via `utils.nt_tree_fn`.
        """
        masked_tree = utils.get_masked_array(f(p, x, **apply_fn_kwargs))
        # TODO(romann): normalize properly if output is masked.
        take_values = utils.nt_tree_fn()(lambda leaf: leaf.masked_value)
        return take_values(masked_tree)
Example #2
0
    def _f(p):
        """Evaluates `f(p, x, **apply_fn_kwargs)`, unwraps masks and squeezes.

        The output goes through `utils.get_masked_array`, then the
        `.masked_value` of every tree leaf is taken, and finally `_squeeze` is
        applied along `fx_axis`.
        """
        masked_tree = utils.get_masked_array(f(p, x, **apply_fn_kwargs))
        # TODO(romann): normalize properly if output is masked.
        take_values = utils.nt_tree_fn()(lambda leaf: leaf.masked_value)
        return _squeeze(take_values(masked_tree), fx_axis)
Example #3
0
    def test_parallel_in_out(self, same_inputs):
        """Batched and un-batched kernels agree for nested `stax.parallel` nets.

        A nested tuple input `(x_1, (x_2, x_3))` is passed through a "read-in"
        parallel network of width `WIDTH`, and the resulting kernels through a
        width-1 "read-out" parallel network. The `jit`-ed kernel functions are
        compared against their `batch.batch(..., 2)` counterparts (with `pmap`
        stubbed out in `batch` via `test_utils.stub_out_pmap(batch, 2)`), first
        for `get='nngp'` and then for the NTK of `get=('nngp', 'ntk')`.

        Args:
          same_inputs: if True, `x2` is `None` (kernels of `x1` against itself,
            presumably equivalent to `x2 = x1` — confirm against `stax`);
            otherwise a separate `x2` with a different batch size is sampled.
        """
        test_utils.stub_out_pmap(batch, 2)
        rng = random.PRNGKey(0)
        # Keep the 3-way split (third key unused) so sampled inputs stay the
        # same as before.
        input_key1, input_key2, _ = random.split(rng, 3)

        x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 10))
        x1 = (x1_1, (x1_2, x1_3))

        # `same_inputs` was previously ignored, so the `x2=None` path was never
        # exercised by this parameterized test.
        if same_inputs:
            x2 = None
        else:
            x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 10))
            x2 = (x2_1, (x2_2, x2_3))

        N = WIDTH

        def net(N_out):
            return stax.parallel(
                stax.Dense(N_out),
                stax.parallel(stax.Dense(N_out + 1), stax.Dense(N_out + 2)))

        readin = net(N)
        readout = net(1)

        # The read-in kernel function is identical for both checks below.
        K_readin_fn = jit(readin[2])
        batch_K_readin_fn = batch.batch(K_readin_fn, 2)

        # Check NNGP.
        K_readout_fn = jit(partial(readout[2], get='nngp'))
        batch_K_readout_fn = batch.batch(K_readout_fn, 2)

        test_utils.assert_close_matrices(
            self, K_readout_fn(K_readin_fn(x1, x2)),
            batch_K_readout_fn(batch_K_readin_fn(x1, x2)), RTOL)

        # Check Both.
        K_readout_fn = jit(partial(readout[2], get=('nngp', 'ntk')))
        batch_K_readout_fn = batch.batch(K_readout_fn, 2)

        get_ntk = utils.nt_tree_fn()(lambda k: k.ntk)

        test_utils.assert_close_matrices(
            self, get_ntk(K_readout_fn(K_readin_fn(x1, x2))),
            get_ntk(batch_K_readout_fn(batch_K_readin_fn(x1, x2))), RTOL)
Example #4
0
    def parallel_fn_kernel(kernel, *args, **kwargs):
        """Runs `kernel_fn` on `kernel` reshaped for multi-device execution.

        The leading sizes of `cov1` / `cov2` are read from the first kernel in
        the tree and checked with `_check_dropout`. The kernel is then reshaped
        by `_reshape_kernel_for_pmap` and passed to
        `_jit_or_pmap_broadcast(kernel_fn, ...)`. If every input kernel had
        `cov2 is None`, `_set_cov2_is_none` is applied to the result before
        it is handed to `_flatten_kernel`.

        Args:
          kernel: a `Kernel` or a list/tuple tree of kernels.
          *args: forwarded to `kernel_fn`.
          **kwargs: forwarded to `kernel_fn`; also passed to `_check_dropout`.

        Returns:
          The value returned by `_flatten_kernel`.
        """
        def leading_sizes(single):
            # Without a cov2, both sizes are taken from cov1.
            rows = single.cov1.shape[0]
            cols = rows if single.cov2 is None else single.cov2.shape[0]
            return rows, cols

        # For a tree of kernels, only the first kernel's sizes are kept.
        get_batch_sizes = utils.nt_tree_fn(
            reduce=lambda shapes: shapes[0])(leading_sizes)

        n1, n2 = get_batch_sizes(kernel)
        _check_dropout(n1, n2, kwargs)
        rows_per_device, n_devices = _get_n_per_device(n1, n2)

        mapped_kernel_fn = _jit_or_pmap_broadcast(kernel_fn, n_devices)

        # Computed on the input kernel, before it is reshaped.
        every_cov2_none = utils.nt_tree_fn(
            reduce=lambda flags: all(flags))(lambda kern: kern.cov2 is None)(kernel)
        sharded = _reshape_kernel_for_pmap(kernel, n_devices, rows_per_device)
        result = mapped_kernel_fn(sharded, *args, **kwargs)
        if every_cov2_none:
            result = _set_cov2_is_none(result)
        return _flatten_kernel(result, every_cov2_none, True)
Example #5
0
 def output(x, **kwargs):
     """Returns `f(params, x, **kwargs)` with `.masked_value` taken per leaf.

     The raw output is wrapped by `utils.get_masked_array`, then
     `utils.nt_tree_fn` maps `.masked_value` over each leaf of that tree.
     """
     masked_tree = utils.get_masked_array(f(params, x, **kwargs))
     # Lambda parameter renamed so it no longer shadows the `x` argument.
     take_values = utils.nt_tree_fn()(lambda leaf: leaf.masked_value)
     return take_values(masked_tree)
Example #6
0
    def serial_fn_kernel(k: NTTree[Kernel], *args, **kwargs) -> NTTree[Kernel]:
        """Evaluates `kernel_fn` on `k` one (row-batch, column-batch) tile at a time.

        The sizes `n1, n2` come from the leading two dims of `nngp` on the
        first kernel found in the tree. These sizes are split into batches by
        `_get_n_batches_and_batch_sizes`. Keyword arguments for which
        `_is_np_ndarray` holds are expected to be `(arr1, arr2)` pairs; each
        array is reshaped to `(n_batches, batch_size, ...)` so it can be
        scanned alongside the row / column offsets. Every other keyword
        argument is forwarded unchanged. A nested `_scan` (rows outer,
        columns inner) slices the input kernel tile by tile and applies
        `kernel_fn`. If every input kernel had `cov2 is None`,
        `_set_cov2_is_none` is applied to the stacked result before `flatten`.
        """
        # pytype: disable=attribute-error
        def get_n1_n2(tree):
            # Walk down first elements of nested lists/tuples to a kernel.
            # TODO(schsam): We might want to check for consistency here, but I can't
            # imagine a case where we could get inconsistent kernels.
            while utils.is_list_or_tuple(tree):
                tree = tree[0]
            return tree.nngp.shape[:2]

        # pytype: enable=attribute-error
        n1, n2 = get_n1_n2(k)

        (n1_batches, n1_batch_size,
         n2_batches, n2_batch_size) = _get_n_batches_and_batch_sizes(
             n1, n2, batch_size, device_count)

        row_starts = np.arange(0, n1, n1_batch_size)
        col_starts = np.arange(0, n2, n2_batch_size)

        @utils.nt_tree_fn(nargs=1)
        def slice_kernel(kern, rows, cols):
            return kern.slice(rows, cols)

        def split_leading(arr, n_batches, size):
            # (n, ...) -> (n_batches, size, ...)
            return np.reshape(arr, (n_batches, size) + arr.shape[1:])

        kwargs_np1 = {}
        kwargs_np2 = {}
        kwargs_other = {}
        for name, value in kwargs.items():
            if not _is_np_ndarray(value):
                kwargs_other[name] = value
                continue
            assert isinstance(value, tuple) and len(value) == 2
            kwargs_np1[name] = split_leading(value[0], n1_batches, n1_batch_size)
            kwargs_np2[name] = split_leading(value[1], n2_batches, n2_batch_size)

        def col_fn(row_carry, col_x):
            # NOTE(schsam): If we end up wanting to enable jit-of-batch then we will
            # probably have to change this to dynamic slicing.
            row_start, row_kwargs = row_carry
            col_start, col_kwargs = col_x
            merged_kwargs = {
                **kwargs_other,
                **{name: (row_kwargs[name], col_kwargs[name])
                   for name in row_kwargs},
            }
            tile = slice_kernel(k,
                                slice(row_start, row_start + n1_batch_size),
                                slice(col_start, col_start + n2_batch_size))
            return (row_start, row_kwargs), kernel_fn(tile, *args, **merged_kwargs)

        def row_fn(carry, row_x):
            # The row entry is the (constant) carry of the inner column scan.
            _, row_tiles = _scan(col_fn, row_x, (col_starts, kwargs_np2))
            return carry, row_tiles

        every_cov2_none = utils.nt_tree_fn(
            reduce=lambda flags: all(flags))(lambda kern: kern.cov2 is None)(k)
        # Bind the scan output to a new name rather than rebinding `k`, which
        # `col_fn` reads through its closure.
        _, out = _scan(row_fn, 0, (row_starts, kwargs_np1))
        if every_cov2_none:
            out = _set_cov2_is_none(out)
        return flatten(out, every_cov2_none)