def apply_fun(params, inputs, **kwargs):
    """Run the composed layers one after another.

    Args:
        params: Indexable collection; ``params[i]`` is handed to layer ``i``.
        inputs: Value given to the first layer. Each layer's result becomes
            the input of the following layer.
        **kwargs: Forwarded unchanged to every layer. An optional ``rng``
            entry is removed first and replaced by a per-layer value.

    Returns:
        The result of the last layer, or ``inputs`` itself when ``nlayers``
        is 0.
    """
    rng = kwargs.pop('rng', None)
    if rng is None:
        # No randomness supplied: every layer is called with rng=None.
        layer_rngs = (None,) * nlayers
    else:
        seed = tf.convert_to_tensor(rng, dtype=tf.int32)
        layer_rngs = split(seed=seed, num=nlayers)

    out = inputs
    for idx in range(nlayers):
        layer_apply = apply_funs[idx]
        out = layer_apply(params[idx], out, rng=layer_rngs[idx], **kwargs)
    return out
def init_fun(rng, input_shape):
    """Create the weight matrix and bias vector for the dense layer.

    Args:
        rng: Seed value; converted to an int32 tensor and split in two,
            one part for the weights and one for the bias.
        input_shape: Shape of the layer input. Its last entry is used as
            the input feature size; it must support ``[:-1] + (out_dim,)``.

    Returns:
        A pair ``(output, (W, b))`` where ``output`` is
        ``tfnp.zeros(output_shape)`` and ``W`` / ``b`` are the ``.numpy()``
        values of the initialised weight and bias.
    """
    def _scalar_seed(key):
        # Convert a key from shape (2,) into a scalar.
        return stateless_uniform(
            shape=[], seed=key, minval=None, maxval=None, dtype=tf.int32)

    output_shape = input_shape[:-1] + (out_dim,)

    key_pair = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2)
    weight_seed = _scalar_seed(key_pair[0])
    bias_seed = _scalar_seed(key_pair[1])

    weight_shape = (input_shape[-1], out_dim)
    W = W_init(seed=weight_seed, shape=weight_shape)
    b = b_init(seed=bias_seed, shape=(out_dim,))

    placeholder_output = tfnp.zeros(output_shape)
    return placeholder_output, (W.numpy(), b.numpy())
def apply_fun(params, inputs, **kwargs):
    """Apply every branch to its own input and collect the results.

    Args:
        params: Indexable collection; ``params[i]`` is handed to branch ``i``.
        inputs: Indexable collection; ``inputs[i]`` is the input of branch
            ``i``.
        **kwargs: Forwarded unchanged to every branch. An optional ``rng``
            entry is removed first and replaced by a per-branch value.

    Returns:
        A list with one result per entry of ``apply_funs``, in order.
    """
    rng = kwargs.pop('rng', None)
    if rng is None:
        # No randomness supplied: every branch is called with rng=None.
        branch_rngs = (None,) * nlayers
    else:
        seed = tf.convert_to_tensor(rng, dtype=tf.int32)
        branch_rngs = split(seed=seed, num=nlayers)

    return [
        branch_apply(params[idx], inputs[idx], rng=branch_rngs[idx], **kwargs)
        for idx, branch_apply in enumerate(apply_funs)
    ]
def init_fun(rng, input_shape):
    """Initialise each layer in order, threading the shape through them.

    For every entry of ``init_funs`` the current seed is split in two: the
    first part is carried forward to the next iteration and the second part
    is given to the layer being initialised.

    Args:
        rng: Starting seed; converted to an int32 tensor before each split.
        input_shape: Shape description fed to the first layer. It is passed
            through ``shape_conversion`` before every layer call.

    Returns:
        A pair ``(shape, params)``: the first element returned by the last
        layer initialiser (or ``input_shape`` unchanged when ``init_funs`` is
        empty) and the list of per-layer parameters.
    """
    all_params = []
    shape = input_shape
    carry_rng = rng
    for layer_init in init_funs:
        key_pair = split(
            seed=tf.convert_to_tensor(carry_rng, dtype=tf.int32), num=2)
        carry_rng, layer_rng = key_pair[0], key_pair[1]
        shape, layer_params = layer_init(layer_rng, shape_conversion(shape))
        all_params.append(layer_params)
    return shape, all_params
def init_fun(rng, input_shape):
    """Create the kernel and bias for the convolution layer.

    Args:
        rng: Seed value; converted to an int32 tensor and split in two,
            one part for the kernel and one for the bias.
        input_shape: Shape description of the layer input; normalised with
            ``shape_conversion`` before use.

    Returns:
        A pair ``(output, (W, b))`` where ``output`` is
        ``tfnp.zeros(output_shape)`` and ``W`` / ``b`` are the values returned
        by ``W_init`` and ``b_init``.
    """
    input_shape = shape_conversion(input_shape)

    # Walk the kernel layout spec: 'O' takes the output channel count, 'I'
    # takes the input's channel dimension, and any other axis consumes the
    # next entry of filter_shape in order.
    remaining_filter_dims = iter(filter_shape)
    kernel_shape = []
    for axis in rhs_spec:
        if axis == 'O':
            kernel_shape.append(out_chan)
        elif axis == 'I':
            kernel_shape.append(input_shape[lhs_spec.index('C')])
        else:
            kernel_shape.append(next(remaining_filter_dims))

    output_shape = conv_general_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)

    # Bias is out_chan on the 'C' axis and 1 elsewhere, with leading 1s
    # stripped.
    full_bias_shape = [out_chan if axis == 'C' else 1 for axis in out_spec]
    bias_shape = tuple(
        itertools.dropwhile(lambda dim: dim == 1, full_bias_shape))

    key_pair = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2)
    kernel_key = key_pair[0]
    bias_key = key_pair[1]

    W = W_init(seed=kernel_key, shape=kernel_shape)
    b = b_init(stddev=1e-6, seed=bias_key, shape=bias_shape)
    return tfnp.zeros(output_shape), (W, b)
def init_fun(rng, input_shape):
    """Initialise every parallel branch with its own seed and input shape.

    Args:
        rng: Seed value; converted to an int32 tensor and split into
            ``nlayers`` parts, one per branch.
        input_shape: Indexable collection; ``input_shape[i]`` is passed to
            branch ``i``.

    Returns:
        A tuple obtained by transposing the per-branch results. When every
        branch returns an ``(output_shape, params)`` pair this is
        ``(output_shapes, params)``, each itself a tuple with one entry per
        branch.
    """
    rngs = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=nlayers)
    result = [init_funs[i](rngs[i], input_shape[i]) for i in range(nlayers)]
    # zip() is a lazy one-shot iterator in Python 3: it cannot be indexed and
    # is silently empty on a second pass. Materialise it so the result can be
    # unpacked, indexed, or iterated repeatedly.
    return tuple(zip(*result))