コード例 #1
0
    def testComposition(self, train_shape, test_shape, network, name,
                        kernel_fn, batch_size):
        """Check nested `_serial`/`_parallel` batching against the raw kernel.

        Two compositions are exercised, in this order:
          1. `_parallel(_serial(kernel_fn, batch_size))`
          2. `_serial(_parallel(kernel_fn), batch_size)`
        Each is compared via `_test_kernel_against_batched` with the
        unbatched `kernel_fn` on the same pair of random inputs.

        `name` comes from the test parameterization and is not read here.
        """
        # NOTE(review): presumably stubs pmap so `batch._parallel` sees 2
        # devices -- confirm in test_utils.
        test_utils.stub_out_pmap(batch, 2)

        # Deterministic seed tensor (fixed seed=[0, 0]); split into three
        # sub-keys: kernel init, first input, second input.
        root_key = stateless_uniform(shape=[2],
                                     seed=[0, 0],
                                     minval=None,
                                     maxval=None,
                                     dtype=tf.int32)
        sub_keys = tf_random_split(root_key, 3)
        init_key, x1_key, x2_key = sub_keys[0], sub_keys[1], sub_keys[2]

        x1 = np.asarray(normal(train_shape, seed=x1_key))
        x2 = np.asarray(normal(test_shape, seed=x2_key))

        # `kernel_fn` arrives as a factory; replace it with the built kernel.
        kernel_fn = kernel_fn(init_key, train_shape[1:], network)

        def parallel_of_serial(fn):
            return batch._parallel(batch._serial(fn, batch_size=batch_size))

        def serial_of_parallel(fn):
            return batch._serial(batch._parallel(fn), batch_size=batch_size)

        # Built and checked one at a time, in the same order as before.
        for compose in (parallel_of_serial, serial_of_parallel):
            _test_kernel_against_batched(self, kernel_fn, compose(kernel_fn),
                                         x1, x2)
コード例 #2
0
ファイル: batch_test.py プロジェクト: saonam/neural-tangents
  def testComposition(self, train_shape, test_shape, network, name, kernel_fn,
                      batch_size=2):
    """Checks nested `_serial`/`_parallel` batching against the raw kernel.

    Both `_parallel(_serial(kernel_fn))` and `_serial(_parallel(kernel_fn))`
    are compared with the unbatched `kernel_fn` on the same random inputs.

    Args:
      train_shape: shape of the first input, drawn with `random.normal`.
      test_shape: shape of the second input, drawn with `random.normal`.
      network: forwarded to the `kernel_fn` factory.
      name: test-case name from the parameterization; not read here.
      kernel_fn: factory called as `kernel_fn(key, train_shape[1:], network)`
        that returns the kernel function under test.
      batch_size: batch size handed to `batch._serial`. Defaults to 2, the
        value previously hard-coded here, so existing parameterizations are
        unaffected while matching the sibling tests that accept it.
    """
    # NOTE(review): presumably stubs pmap to emulate 2 devices -- confirm in
    # utils. This 2 is a device count, unrelated to `batch_size`.
    utils.stub_out_pmap(batch, 2)

    key = random.PRNGKey(0)
    key, self_split, other_split = random.split(key, 3)
    data_self = random.normal(self_split, train_shape)
    data_other = random.normal(other_split, test_shape)

    kernel_fn = kernel_fn(key, train_shape[1:], network)

    kernel_batched = batch._parallel(
        batch._serial(kernel_fn, batch_size=batch_size))
    _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                 data_other)

    kernel_batched = batch._serial(batch._parallel(kernel_fn),
                                   batch_size=batch_size)
    _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                 data_other)
コード例 #3
0
    def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
                   batch_size):
        """Compare `batch._serial(kernel_fn)` with the unbatched kernel.

        `kernel_fn` is a factory invoked as
        `kernel_fn(key, train_shape[1:], network)`; the kernel it returns is
        checked against its `_serial`-batched wrapper (with `batch_size`) on
        two normally distributed inputs of `train_shape` and `test_shape`.
        `name` comes from the test parameterization and is not read here.
        """
        root_key = random.PRNGKey(0)
        init_key, x1_key, x2_key = random.split(root_key, 3)

        x1 = random.normal(x1_key, train_shape)
        x2 = random.normal(x2_key, test_shape)

        unbatched = kernel_fn(init_key, train_shape[1:], network)
        _test_kernel_against_batched(
            self, unbatched, batch._serial(unbatched, batch_size=batch_size),
            x1, x2)