示例#1
0
 def testSplit(self, dtype):
     """Check that `split` derives distinct seeds of the expected shape/dtype.

     A 2-element seed of ``dtype`` is split into three new seeds; the result
     must have shape ``[3, 2]``, keep ``dtype``, and no two seeds among the
     original and the derived ones may be equal.
     """
     num_splits = 3
     base_seed = constant_op.constant([1, 2], dtype=dtype)
     derived = stateless.split(base_seed, num_splits)
     self.assertEqual(derived.shape, [num_splits, 2])
     self.assertDTypeEqual(derived.dtype, dtype)
     every_seed = [base_seed]
     every_seed.extend(array_ops.unstack(derived))
     self.assertNoEqualPair(every_seed)
示例#2
0
 def apply_fun(params, inputs, **kwargs):
     """Apply the serially composed layers to ``inputs``, one after another.

     Args:
       params: per-layer parameters; ``params[i]`` is passed to
         ``apply_funs[i]``.
       inputs: input to the first layer; each layer's output becomes the
         next layer's input.
       **kwargs: forwarded to every layer. An optional ``rng`` entry is popped
         and split into one key per layer (``None`` for every layer if absent).

     Returns:
       The output of the last layer.
     """
     rng = kwargs.pop('rng', None)
     # Split into exactly one key per layer. The default `random.split(rng)`
     # yields only two keys, so indexing `rngs[i]` for `nlayers > 2` would
     # raise or, on backends that clamp out-of-range indices, reuse a key.
     if rng is not None:
         rngs = random.split(rng, num=nlayers)
     else:
         rngs = (None, ) * nlayers
     for i in range(nlayers):
         inputs = apply_funs[i](params[i], inputs, rng=rngs[i], **kwargs)
     return inputs
示例#3
0
 def init_fun(rng, input_shape):
     """Initialize the serially composed layers, threading shapes through.

     Args:
       rng: key; one subkey is split off per layer while the remainder is
         carried forward to the next layer.
       input_shape: shape fed to the first layer; each layer's output shape
         becomes the next layer's input shape.

     Returns:
       ``(output_shape, params)`` where ``output_shape`` is the last layer's
       output shape and ``params`` is the list of per-layer parameters.
     """
     params = []
     # `layer_init` (not `init_fun`) so the loop variable does not shadow
     # this function's own name.
     for layer_init in init_funs:
         keys = random.split(rng)
         rng = keys[0]  # carried forward to later layers
         layer_rng = keys[1]  # consumed by this layer only
         input_shape, param = layer_init(layer_rng, input_shape)
         params.append(param)
     return input_shape, params
 def kernel_fn_sample_once(x1: np.ndarray, x2: Optional[np.ndarray],
                           key: PRNGKey, get: Get, **apply_fn_kwargs):
     """Draw one kernel sample from freshly initialized parameters.

     ``key`` is split into two subkeys. The first initializes parameters via
     ``init_fn(subkey, x1.shape)``. The second is passed to ``kernel_fn`` as
     ``rng`` (presumably for stochastic layers such as dropout).
     """
     subkeys = random.split(key, 2)
     param_key = subkeys[0]
     layer_rng = subkeys[1]
     _, params = init_fn(param_key, x1.shape)
     sample = kernel_fn(x1, x2, get, params, rng=layer_rng, **apply_fn_kwargs)
     return sample
示例#5
0
 def apply_fun(params, inputs, **kwargs):
     """Apply each parallel layer to its own slot of ``inputs``.

     ``apply_funs[idx]`` receives ``params[idx]`` and ``inputs[idx]``. An
     optional ``rng`` in ``kwargs`` is popped and split into ``nlayers`` keys,
     one per layer (``None`` for every layer when absent). All remaining
     ``kwargs`` are forwarded to every layer.

     Returns:
       A list with one output per layer.
     """
     rng = kwargs.pop('rng', None)
     rngs = ((None, ) * nlayers if rng is None else
             random.split(rng, num=nlayers))
     return [
         apply_funs[idx](params[idx], inputs[idx], rng=rngs[idx], **kwargs)
         for idx in range(len(apply_funs))
     ]
 def get_samples(x1: np.ndarray, x2: Optional[np.ndarray], get: Get,
                 **apply_fn_kwargs):
     """Yield running sums of kernel samples.

     Draws ``max(n_samples)`` samples. Each one uses a fresh subkey split off
     an evolving copy of ``key``. After drawing sample ``n``, the generator
     yields ``(n, total)``, where ``total`` is the leafwise sum
     (``tree_multimap(operator.add, ...)``) of the first ``n`` samples.
     """
     running_key = key
     total = None
     for n in range(1, max(n_samples) + 1):
         pair = random.split(running_key)
         running_key = pair[0]
         sample_key = pair[1]
         sample = kernel_fn_sample_once(x1, x2, sample_key, get,
                                        **apply_fn_kwargs)
         total = (sample if total is None else
                  tree_multimap(operator.add, total, sample))
         yield n, total
示例#7
0
 def init_fun(rng, input_shape):
     """Compute the convolution output shape and initialize ``(W, b)``.

     The kernel shape follows ``rhs_spec``:
       * 'O' becomes ``out_chan``;
       * 'I' becomes the input's 'C' dimension (located via ``lhs_spec``);
       * any other letter takes the next entry of ``filter_shape``.

     The bias shape has ``out_chan`` on the 'C' axis of ``out_spec`` and 1
     elsewhere, with leading 1s dropped.

     Returns:
       ``(output_shape, (W, b))``.
     """
     filter_dims = iter(filter_shape)
     kernel_shape = []
     for dim in rhs_spec:
         if dim == 'O':
             kernel_shape.append(out_chan)
         elif dim == 'I':
             kernel_shape.append(input_shape[lhs_spec.index('C')])
         else:
             kernel_shape.append(next(filter_dims))
     output_shape = lax.conv_general_shape_tuple(
         input_shape, kernel_shape, strides, padding, dimension_numbers)
     full_bias_shape = [out_chan if dim == 'C' else 1 for dim in out_spec]
     bias_shape = tuple(
         itertools.dropwhile(lambda size: size == 1, full_bias_shape))
     key_pair = random.split(rng)
     w_key, b_key = key_pair[0], key_pair[1]
     W = W_init(seed=w_key, shape=kernel_shape)
     b = b_init(stddev=1e-6, seed=b_key, shape=bias_shape)
     return output_shape, (W, b)
示例#8
0
 def init_fun(rng, input_shape):
     """Compute the dense-layer output shape and initialize ``(W, b)``.

     The last input dimension is mapped to ``out_dim``. ``rng`` is split into
     two keys. Each key is collapsed into a scalar int32 seed with
     ``stateless_uniform`` before being handed to ``W_init`` / ``b_init``.

     Returns:
       ``(output_shape, (W, b))`` with ``W`` and ``b`` passed through
       ``np.asarray``.
     """
     def _scalar_seed(pair_key):
         # Collapse a shape-(2,) key into a single int32 scalar seed.
         return stateless_uniform(shape=[],
                                  seed=pair_key,
                                  minval=None,
                                  maxval=None,
                                  dtype=tf.int32)

     output_shape = input_shape[:-1] + (out_dim, )
     halves = random.split(rng)
     w_seed = _scalar_seed(halves[0])
     b_seed = _scalar_seed(halves[1])
     weights = W_init(seed=w_seed, shape=(input_shape[-1], out_dim))
     bias = b_init(seed=b_seed, shape=(out_dim, ))
     return output_shape, (np.asarray(weights), np.asarray(bias))
示例#9
0
 def init_fun(rng, input_shape):
     """Initialize each parallel layer on its own input shape.

     Args:
       rng: key, split into ``nlayers`` independent per-layer keys.
       input_shape: sequence of shapes; ``input_shape[i]`` is passed to
         ``init_funs[i]``.

     Returns:
       ``zip(*results)`` over the per-layer ``(output_shape, params)`` pairs.
       Unpacking it gives the tuple of output shapes and the tuple of params.
     """
     # One key per layer. The default `random.split(rng)` yields only two
     # keys, which is wrong whenever `nlayers != 2` (matches the `num=nlayers`
     # split used by the parallel `apply_fun`).
     rngs = random.split(rng, num=nlayers)
     result = [init_funs[i](rngs[i], input_shape[i]) for i in range(nlayers)]
     return zip(*result)
示例#10
0
    def serial_fn_x1(x1: np.ndarray,
                     x2: np.ndarray = None,
                     *args,
                     **kwargs) -> _KernelType:
        """Evaluate ``kernel_fn`` over batches of ``x1`` x ``x2`` via nested scans.

        ``x1`` and ``x2`` are reshaped along their leading axis into
        ``(n_batches, batch_size, ...)``. The batch counts and sizes come from
        ``_get_n_batches_and_batch_sizes`` using the enclosing ``batch_size``
        and ``device_count``. An outer ``_scan`` walks the ``x1`` batches and
        an inner ``_scan`` walks the ``x2`` batches, calling ``kernel_fn`` on
        every pair. The stacked result is passed to ``flatten``.

        Keyword arguments are partitioned as follows:
          * ``rng`` (when ``_is_np_ndarray`` accepts it): split into an
            ``x1``-side and an ``x2``-side key, each further split into one
            key per batch on that side.
          * any other kwarg accepted by ``_is_np_ndarray``: must be a 2-tuple
            ``(x1_side, x2_side)`` (checked with ``assert``). Each half is
            reshaped into batches along its leading axis.
          * everything else is passed unchanged to every ``kernel_fn`` call.

        Args:
          x1: first input; its leading axis is split into batches.
          x2: second input, or ``None`` to reuse ``x1``. Whether it was
            ``None`` is forwarded to ``flatten``.
          *args: forwarded positionally to ``kernel_fn``.
          **kwargs: partitioned as described above.

        Returns:
          ``flatten(kernel, x2_is_none)``. Presumably this reassembles the
          per-batch blocks into the full kernel; TODO confirm against
          ``flatten``.
        """

        x2_is_none = x2 is None
        if x2_is_none:
            # TODO(schsam): Only compute the upper triangular part of the kernel.
            x2 = x1

        n1, n2 = x1.shape[0], x2.shape[0]
        (n1_batches, n1_batch_size, n2_batches,
         n2_batch_size) = _get_n_batches_and_batch_sizes(
             n1, n2, batch_size, device_count)
        # Per-batch kwargs for the x1 side (kwargs_np1) and x2 side
        # (kwargs_np2); non-array kwargs are shared by every call.
        kwargs_np1 = {}
        kwargs_np2 = {}
        kwargs_other = {}
        for k, v in kwargs.items():
            if _is_np_ndarray(v):
                if k == 'rng':
                    # Independent keys for the x1 and x2 sides, then one key
                    # per batch on each side.
                    key1, key2 = random.split(v)
                    v1 = random.split(key1, n1_batches)
                    v2 = random.split(key2, n2_batches)
                else:
                    # Other array kwargs arrive as an (x1-side, x2-side) pair;
                    # batch each half the same way as its input.
                    assert isinstance(v, tuple) and len(v) == 2
                    v1 = np.reshape(v[0], (
                        n1_batches,
                        n1_batch_size,
                    ) + v[0].shape[1:])
                    v2 = np.reshape(v[1], (
                        n2_batches,
                        n2_batch_size,
                    ) + v[1].shape[1:])
                kwargs_np1[k] = v1
                kwargs_np2[k] = v2
            else:
                kwargs_other[k] = v
        # x2 is reshaped with x1's trailing shape, so both inputs are expected
        # to share everything but the leading axis.
        input_shape = x1.shape[1:]
        x1s = np.reshape(x1, (
            n1_batches,
            n1_batch_size,
        ) + input_shape)
        x2s = np.reshape(x2, (
            n2_batches,
            n2_batch_size,
        ) + input_shape)

        # Outer scan body: for one x1 batch (paired with its kwargs), scan over
        # all x2 batches and keep the stacked per-column outputs. Assumes
        # `_scan` returns `(carry, stacked_outputs)` like `lax.scan` — confirm.
        # `col_fn` is defined below; it is only looked up when `row_fn` runs.
        def row_fn(_, x1):
            return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]

        # Inner scan body: the carry is the fixed (x1 batch, x1-side kwargs);
        # each array kwarg is re-paired as (x1-side, x2-side) before calling
        # kernel_fn, and the carry is returned unchanged.
        def col_fn(x1, x2):
            x1, kwargs1 = x1
            x2, kwargs2 = x2
            kwargs_merge = {
                **kwargs_other,
                **dict((k, (kwargs1[k], kwargs2[k])) for k in kwargs1)
            }
            return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)

        _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
        return flatten(kernel, x2_is_none)