Example #1
0
    def test_sample_once_batch(self, batch_size, device_count, store_on_device,
                               get):
        """Batched single-sample kernel must match the unbatched one.

        Builds an empirical kernel function from the model returned by
        `_get_inputs_and_model` and wraps it with
        `monte_carlo._sample_once_kernel_fn` twice. The first wrap passes no
        batching arguments. The second passes the given `batch_size`,
        `device_count` and `store_on_device`. The test asserts that both
        wrapped functions give numerically close results for the same
        inputs, key and `get`.

        Args:
          batch_size: forwarded to `monte_carlo._sample_once_kernel_fn`.
          device_count: forwarded to `utils.stub_out_pmap` and to
            `monte_carlo._sample_once_kernel_fn`.
          store_on_device: forwarded to `monte_carlo._sample_once_kernel_fn`.
          get: forwarded unchanged to both sampling functions.
        """
        # NOTE(review): presumably patches `pmap` in the `batch` module so that
        # `device_count` devices can be simulated on one host -- confirm.
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
        kernel_fn = empirical.empirical_kernel_fn(apply_fn)

        # Reference: no batching arguments passed.
        sample_once_fn = monte_carlo._sample_once_kernel_fn(kernel_fn, init_fn)
        # Under test: same kernel, with the batching configuration applied.
        sample_once_batch_fn = monte_carlo._sample_once_kernel_fn(
            kernel_fn, init_fn, batch_size, device_count, store_on_device)

        one_sample = sample_once_fn(x1, x2, key, get)
        one_sample_batch = sample_once_batch_fn(x1, x2, key, get)
        # Pass `check_dtypes` by keyword: recent JAX versions make it
        # keyword-only, so the positional form raises TypeError there.
        self.assertAllClose(one_sample, one_sample_batch, check_dtypes=True)
Example #2
0
    def test_sample_once_batch(self, batch_size, device_count, store_on_device,
                               get):
        """Batched single-sample kernel must match the unbatched one.

        Skipped when `get` is None. Otherwise builds an empirical kernel
        function from the model returned by `_get_inputs_and_model` and wraps
        it with `monte_carlo._sample_once_kernel_fn` twice. The first wrap
        passes no batching arguments. The second passes the given
        `batch_size`, `device_count` and `store_on_device`. The test asserts
        that both wrapped functions give numerically close results for the
        same inputs, key and `get`.

        Args:
          batch_size: forwarded to `monte_carlo._sample_once_kernel_fn`.
          device_count: forwarded to `utils.stub_out_pmap` and to
            `monte_carlo._sample_once_kernel_fn`.
          store_on_device: forwarded to `monte_carlo._sample_once_kernel_fn`.
          get: forwarded unchanged to both sampling functions; `None` skips
            the test.

        Raises:
          jtu.SkipTest: if `get` is None.
        """
        # Guard clause first: skip before stubbing pmap or building the model,
        # since none of that work is needed when the test cannot run.
        if get is None:
            raise jtu.SkipTest('No default `get` values for this method.')

        # NOTE(review): presumably patches `pmap` in the `batch` module so that
        # `device_count` devices can be simulated on one host -- confirm.
        utils.stub_out_pmap(batch, device_count)

        x1, x2, init_fn, apply_fn, _, key = _get_inputs_and_model()
        kernel_fn = empirical.empirical_kernel_fn(apply_fn)

        # Reference: no batching arguments passed.
        sample_once_fn = monte_carlo._sample_once_kernel_fn(kernel_fn, init_fn)
        # Under test: same kernel, with the batching configuration applied.
        sample_once_batch_fn = monte_carlo._sample_once_kernel_fn(
            kernel_fn, init_fn, batch_size, device_count, store_on_device)

        one_sample = sample_once_fn(x1, x2, key, get)
        one_sample_batch = sample_once_batch_fn(x1, x2, key, get)
        # Pass `check_dtypes` by keyword: recent JAX versions make it
        # keyword-only, so the positional form raises TypeError there.
        self.assertAllClose(one_sample, one_sample_batch, check_dtypes=True)