def testSplit(self, dtype):
  """Test for `split`.

  Splitting a length-2 seed three ways must give a [3, 2] result of the same
  dtype, and the parent seed plus every child seed must be pairwise distinct.
  """
  parent = constant_op.constant([1, 2], dtype=dtype)
  children = stateless.split(parent, 3)
  self.assertEqual(children.shape, [3, 2])
  self.assertDTypeEqual(children.dtype, dtype)
  every_seed = [parent] + array_ops.unstack(children)
  self.assertNoEqualPair(every_seed)
def apply_fun(params, inputs, **kwargs):
  """Applies the serial layers in order, threading the activations through.

  Args:
    params: per-layer parameters; `params[i]` is handed to `apply_funs[i]`.
    inputs: input to the first layer.
    **kwargs: forwarded to every layer. An optional `rng` entry is removed
      and replaced by one independent key per layer.

  Returns:
    The output of the last layer.
  """
  rng = kwargs.pop('rng', None)
  if rng is not None:
    # Request exactly one key per layer (as the parallel combinator does with
    # `num=nlayers`). A bare `random.split(rng)` yields only a pair of keys,
    # so stacks deeper than two layers would index past its end.
    rngs = random.split(rng, num=nlayers)
  else:
    rngs = (None,) * nlayers
  for i in range(nlayers):
    inputs = apply_funs[i](params[i], inputs, rng=rngs[i], **kwargs)
  return inputs
def init_fun(rng, input_shape):
  """Initializes the serial layers in order, threading the shape through.

  Each layer receives a fresh subkey split off `rng`. The other half of the
  split is carried forward to the next layer.

  Args:
    rng: key to derive per-layer keys from.
    input_shape: shape of the input to the first layer.

  Returns:
    `(output_shape, params)`, where `output_shape` is the last layer's output
    shape and `params` is a list with one entry per layer.
  """
  params = []
  # `layer_init` (not `init_fun`) so the loop variable does not shadow this
  # function's own name.
  for layer_init in init_funs:
    # Explicit indexing of the split result, consistent with the other split
    # call sites in this module.
    keys = random.split(rng)
    rng, layer_rng = keys[0], keys[1]
    input_shape, param = layer_init(layer_rng, input_shape)
    params.append(param)
  return input_shape, params
def kernel_fn_sample_once(x1: np.ndarray,
                          x2: Optional[np.ndarray],
                          key: PRNGKey,
                          get: Get,
                          **apply_fn_kwargs):
  """Draws one random set of parameters and evaluates `kernel_fn` with it.

  `key` is split in two: the first key initializes the parameters through
  `init_fn(…, x1.shape)`, and the second is passed to `kernel_fn` as `rng`.

  Returns:
    Whatever `kernel_fn(x1, x2, get, params, rng=…, **apply_fn_kwargs)`
    returns.
  """
  subkeys = random.split(key, 2)
  params_key = subkeys[0]
  call_key = subkeys[1]
  _output_shape, params = init_fn(params_key, x1.shape)
  return kernel_fn(x1, x2, get, params, rng=call_key, **apply_fn_kwargs)
def apply_fun(params, inputs, **kwargs):
  """Applies each parallel layer to its own input.

  Args:
    params: per-layer parameters; `params[i]` goes to `apply_funs[i]`.
    inputs: per-layer inputs; `inputs[i]` goes to `apply_funs[i]`.
    **kwargs: forwarded to every layer. An optional `rng` entry is removed
      and split into `nlayers` keys, one per layer.

  Returns:
    A list whose i-th entry is the output of `apply_funs[i]`.
  """
  rng = kwargs.pop('rng', None)
  rngs = (None,) * nlayers if rng is None else random.split(rng, num=nlayers)
  return [
      layer_apply(params[i], inputs[i], rng=rngs[i], **kwargs)
      for i, layer_apply in enumerate(apply_funs)
  ]
def get_samples(x1: np.ndarray,
                x2: Optional[np.ndarray],
                get: Get,
                **apply_fn_kwargs):
  """Yields running sums of independently sampled kernels.

  For n = 1 … max(n_samples), splits a fresh subkey off the running key,
  draws one sample via `kernel_fn_sample_once`, and yields
  `(n, sum_of_first_n_samples)`. Samples are summed leaf-wise with
  `tree_multimap(operator.add, …)`; no averaging happens here.
  """
  running_key = key
  accumulated = None
  last_n = max(n_samples)
  for count in range(1, last_n + 1):
    split_keys = random.split(running_key)
    running_key = split_keys[0]
    sample_key = split_keys[1]
    sample = kernel_fn_sample_once(x1, x2, sample_key, get, **apply_fn_kwargs)
    if accumulated is None:
      accumulated = sample
    else:
      accumulated = tree_multimap(operator.add, accumulated, sample)
    yield count, accumulated
def init_fun(rng, input_shape):
  """Builds the convolution kernel and bias for a given input shape.

  The kernel shape follows `rhs_spec`:
  - 'O' becomes `out_chan`.
  - 'I' becomes the input's channel size (the `'C'` position in `lhs_spec`).
  - Every other letter consumes the next entry of `filter_shape`.

  The bias shape has `out_chan` at the `'C'` position of `out_spec` and 1
  elsewhere. Leading 1s are then stripped (presumably so the bias broadcasts
  against the output's trailing dims; confirm against the apply function).

  Returns:
    `(output_shape, (W, b))`.
  """
  spatial_sizes = iter(filter_shape)
  kernel_shape = []
  for dim_name in rhs_spec:
    if dim_name == 'O':
      kernel_shape.append(out_chan)
    elif dim_name == 'I':
      kernel_shape.append(input_shape[lhs_spec.index('C')])
    else:
      kernel_shape.append(next(spatial_sizes))
  output_shape = lax.conv_general_shape_tuple(
      input_shape, kernel_shape, strides, padding, dimension_numbers)
  full_bias_shape = [out_chan if dim_name == 'C' else 1 for dim_name in out_spec]
  bias_shape = tuple(itertools.dropwhile(lambda size: size == 1, full_bias_shape))
  split_keys = random.split(rng)
  weight_key, bias_key = split_keys[0], split_keys[1]
  W = W_init(seed=weight_key, shape=kernel_shape)
  b = b_init(stddev=1e-6, seed=bias_key, shape=bias_shape)
  return output_shape, (W, b)
def init_fun(rng, input_shape):
  """Initializes a dense layer's weights and bias.

  `rng` is split in two. Each half is turned into a scalar int32 seed before
  being passed to `W_init` / `b_init`.

  Returns:
    `(output_shape, (W, b))`:
    - `output_shape` is `input_shape` with its last entry replaced by
      `out_dim`.
    - W and b are requested with shapes `(input_shape[-1], out_dim)` and
      `(out_dim,)` and converted with `np.asarray`.
  """
  def _scalar_seed(split_key):
    # Collapse one split key into a single full-range int32 seed.
    return stateless_uniform(shape=[], seed=split_key, minval=None,
                             maxval=None, dtype=tf.int32)

  output_shape = input_shape[:-1] + (out_dim,)
  split_keys = random.split(rng)
  weight_seed = _scalar_seed(split_keys[0])
  bias_seed = _scalar_seed(split_keys[1])
  W = W_init(seed=weight_seed, shape=(input_shape[-1], out_dim))
  b = b_init(seed=bias_seed, shape=(out_dim,))
  return output_shape, (np.asarray(W), np.asarray(b))
def init_fun(rng, input_shape):
  """Initializes every parallel branch with its own key and input shape.

  Args:
    rng: key, split into one independent key per branch.
    input_shape: sequence of shapes; `input_shape[i]` goes to `init_funs[i]`.

  Returns:
    `zip(*results)`, where `results[i]` is `init_funs[i]`'s return value. Each
    item of that iterator gathers one position across all branches
    (presumably unpacked by callers as `output_shapes, params`). Note it is a
    one-shot iterator.
  """
  # One key per branch, matching the parallel apply function's
  # `num=nlayers`. A bare `random.split(rng)` yields only a pair, so more than
  # two branches would index past its end.
  rngs = random.split(rng, num=nlayers)
  results = [init_funs[i](rngs[i], input_shape[i]) for i in range(nlayers)]
  return zip(*results)
def serial_fn_x1(x1: np.ndarray,
                 x2: np.ndarray = None,
                 *args,
                 **kwargs) -> _KernelType:
  """Evaluates `kernel_fn` on (x1, x2) one pair of batches at a time.

  `x1` and `x2` are reshaped along their leading axis into
  `(n_batches, batch_size, ...)` blocks. Two nested `_scan` calls then visit
  every (x1 block, x2 block) pair and call `kernel_fn` on each. The stacked
  per-block results are passed to `flatten`.

  Keyword arguments for which `_is_np_ndarray` is true are batched:
  - `rng` is split into two keys, which are further split into one key per
    x1 batch and one key per x2 batch.
  - Any other such kwarg must be a 2-tuple `(value_for_x1, value_for_x2)`
    (enforced by an assert). Each half is reshaped into batches matching x1
    and x2 respectively.

  All remaining kwargs are forwarded unchanged to every `kernel_fn` call.

  Args:
    x1: first input; its leading axis is split into batches.
    x2: second input, or None to reuse `x1`.
    *args: forwarded positionally to every `kernel_fn` call.
    **kwargs: routed as described above.

  Returns:
    `flatten(kernel, x2_is_none)`, where `kernel` holds the stacked per-block
    `kernel_fn` outputs. Presumably `flatten` reassembles the full kernel
    (TODO confirm).
  """
  x2_is_none = x2 is None
  if x2_is_none:
    # TODO(schsam): Only compute the upper triangular part of the kernel.
    x2 = x1

  n1, n2 = x1.shape[0], x2.shape[0]
  # NOTE(review): the reshapes below assume
  # n == n_batches * batch_size for each input; presumably guaranteed by
  # `_get_n_batches_and_batch_sizes`. Verify.
  (n1_batches, n1_batch_size,
   n2_batches, n2_batch_size) = _get_n_batches_and_batch_sizes(
       n1, n2, batch_size, device_count)

  kwargs_np1 = {}  # batched kwargs aligned with x1's batches
  kwargs_np2 = {}  # batched kwargs aligned with x2's batches
  kwargs_other = {}  # kwargs forwarded as-is to every call
  for k, v in kwargs.items():
    if _is_np_ndarray(v):
      if k == 'rng':
        # Separate key streams for the x1 and x2 batch axes, then one key per
        # batch on each axis.
        key1, key2 = random.split(v)
        v1 = random.split(key1, n1_batches)
        v2 = random.split(key2, n2_batches)
      else:
        assert isinstance(v, tuple) and len(v) == 2
        v1 = np.reshape(v[0], (
            n1_batches,
            n1_batch_size,
        ) + v[0].shape[1:])
        v2 = np.reshape(v[1], (
            n2_batches,
            n2_batch_size,
        ) + v[1].shape[1:])
      kwargs_np1[k] = v1
      kwargs_np2[k] = v2
    else:
      kwargs_other[k] = v

  input_shape = x1.shape[1:]
  x1s = np.reshape(x1, (
      n1_batches,
      n1_batch_size,
  ) + input_shape)
  x2s = np.reshape(x2, (
      n2_batches,
      n2_batch_size,
  ) + input_shape)

  # Outer scan body.
  # - The carry `_` is passed through unchanged.
  # - `x1` is one (x1 block, its batched kwargs) slice.
  # - It runs the inner scan over all x2 blocks and keeps only the stacked
  #   outputs (`[1]`).
  # - Assumes `_scan` follows `lax.scan`-style (carry, outputs) semantics.
  #   Confirm.
  def row_fn(_, x1):
    return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]

  # Inner scan body. The carry is the current (x1 block, its kwargs) and is
  # returned unchanged. Each step pairs it with one (x2 block, its kwargs).
  def col_fn(x1, x2):
    x1, kwargs1 = x1
    x2, kwargs2 = x2
    # Rebuild every batched kwarg as an (x1-batch part, x2-batch part) tuple
    # and merge it over the unbatched kwargs.
    kwargs_merge = {
        **kwargs_other,
        **dict((k, (kwargs1[k], kwargs2[k])) for k in kwargs1)
    }
    return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)

  _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
  return flatten(kernel, x2_is_none)