def testComposition(self, train_shape, test_shape, network, name, kernel_fn,
                    batch_size):
  """Exercises both nestings of `batch._serial` and `batch._parallel`.

  The kernel built from `kernel_fn` is wrapped as serial-inside-parallel and
  as parallel-inside-serial. Each wrapped version is passed, together with the
  unwrapped kernel and the same two random inputs, to
  `_test_kernel_against_batched`.

  Args:
    train_shape: shape of the first input; `train_shape[1:]` is also handed
      to `kernel_fn` when building the kernel.
    test_shape: shape of the second input.
    network: forwarded unchanged to `kernel_fn`.
    name: not read in this body (presumably only labels the parameterized
      case — TODO confirm).
    kernel_fn: factory called as `kernel_fn(key, train_shape[1:], network)`.
    batch_size: batch size given to `batch._serial` in both compositions.
  """
  # NOTE(review): presumably fakes a 2-device pmap inside `batch` so that
  # `_parallel` can run without real devices — confirm in test_utils.
  test_utils.stub_out_pmap(batch, 2)

  # Fixed seed; derive one sub-key for building the kernel and one per input.
  root_key = stateless_uniform(shape=[2], seed=[0, 0], minval=None,
                               maxval=None, dtype=tf.int32)
  sub_keys = tf_random_split(root_key, 3)
  init_key, train_key, test_key = sub_keys[0], sub_keys[1], sub_keys[2]

  x_train = np.asarray(normal(train_shape, seed=train_key))
  x_test = np.asarray(normal(test_shape, seed=test_key))

  reference_fn = kernel_fn(init_key, train_shape[1:], network)

  # Serial batching nested inside parallel batching.
  serial_in_parallel = batch._parallel(
      batch._serial(reference_fn, batch_size=batch_size))
  _test_kernel_against_batched(self, reference_fn, serial_in_parallel,
                               x_train, x_test)

  # Parallel batching nested inside serial batching.
  parallel_in_serial = batch._serial(
      batch._parallel(reference_fn), batch_size=batch_size)
  _test_kernel_against_batched(self, reference_fn, parallel_in_serial,
                               x_train, x_test)
def testComposition(self, train_shape, test_shape, network, name, kernel_fn,
                    batch_size=2):
  """Exercises both nestings of `batch._serial` and `batch._parallel`.

  The kernel built from `kernel_fn` is wrapped as serial-inside-parallel and
  as parallel-inside-serial. Each wrapped version is passed, together with the
  unwrapped kernel and the same two random inputs, to
  `_test_kernel_against_batched`.

  Args:
    train_shape: shape of the first input; `train_shape[1:]` is also handed
      to `kernel_fn` when building the kernel.
    test_shape: shape of the second input.
    network: forwarded unchanged to `kernel_fn`.
    name: not read in this body (presumably only labels the parameterized
      case — TODO confirm).
    kernel_fn: factory called as `kernel_fn(key, train_shape[1:], network)`.
    batch_size: batch size given to `batch._serial` in both compositions.
      Defaults to 2, the value previously hard-coded, so existing
      parameterizations behave exactly as before. Callers may now pass it
      explicitly, as they already do for the sibling serial test.
  """
  # NOTE(review): presumably fakes a 2-device pmap inside `batch` so that
  # `_parallel` can run without real devices — confirm in utils. This device
  # count is independent of `batch_size`.
  utils.stub_out_pmap(batch, 2)

  key = random.PRNGKey(0)
  key, self_split, other_split = random.split(key, 3)
  data_self = random.normal(self_split, train_shape)
  data_other = random.normal(other_split, test_shape)

  kernel_fn = kernel_fn(key, train_shape[1:], network)

  # Serial batching nested inside parallel batching.
  kernel_batched = batch._parallel(
      batch._serial(kernel_fn, batch_size=batch_size))
  _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                               data_other)

  # Parallel batching nested inside serial batching.
  kernel_batched = batch._serial(
      batch._parallel(kernel_fn), batch_size=batch_size)
  _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                               data_other)
def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
               batch_size):
  """Checks `batch._serial` against the unbatched kernel.

  Builds a kernel from `kernel_fn`, wraps it with `batch._serial` at the
  requested `batch_size`, and passes both the plain and the wrapped kernel,
  with two random inputs, to `_test_kernel_against_batched`.

  Args:
    train_shape: shape of the first input; `train_shape[1:]` is also handed
      to `kernel_fn` when building the kernel.
    test_shape: shape of the second input.
    network: forwarded unchanged to `kernel_fn`.
    name: not read in this body (presumably only labels the parameterized
      case — TODO confirm).
    kernel_fn: factory called as `kernel_fn(key, train_shape[1:], network)`.
    batch_size: batch size given to `batch._serial`.
  """
  # One fixed root key, split into a kernel-construction key and one per input.
  root_rng = random.PRNGKey(0)
  init_rng, x1_rng, x2_rng = random.split(root_rng, 3)
  x1 = random.normal(x1_rng, train_shape)
  x2 = random.normal(x2_rng, test_shape)

  base_kernel_fn = kernel_fn(init_rng, train_shape[1:], network)
  _test_kernel_against_batched(
      self,
      base_kernel_fn,
      batch._serial(base_kernel_fn, batch_size=batch_size),
      x1,
      x2)