def apply_fun(params, inputs, **kwargs):
   """Run the layers in sequence, feeding each layer's output to the next.

   Args:
     params: per-layer parameters; ``params[i]`` is passed to layer ``i``.
     inputs: input to the first layer.
     **kwargs: forwarded to every layer's apply function. An optional
       ``rng`` entry is removed first. When it is not None it is converted
       to an int32 tensor and split into ``nlayers`` per-layer seeds;
       otherwise every layer receives ``rng=None``.

   Returns:
     The output of the final layer (``inputs`` itself if ``nlayers`` is 0).
   """
   seed = kwargs.pop('rng', None)
   if seed is None:
     layer_rngs = (None,) * nlayers
   else:
     layer_rngs = split(seed=tf.convert_to_tensor(seed, dtype=tf.int32),
                        num=nlayers)
   outputs = inputs
   for idx in range(nlayers):
     outputs = apply_funs[idx](params[idx], outputs, rng=layer_rngs[idx],
                               **kwargs)
   return outputs
 def init_fun(rng, input_shape):
   """Initialize a dense layer's weight matrix and bias vector.

   Args:
     rng: seed convertible to an int32 tensor. It is split into two keys,
       one for the weights and one for the bias.
     input_shape: shape of the layer input. Its last entry is the weight
       matrix's first dimension, and the output shape replaces that entry
       with ``out_dim``.

   Returns:
     ``(zeros, (W, b))``. ``zeros`` is ``tfnp.zeros`` of the output shape.
     ``W`` has shape ``(input_shape[-1], out_dim)`` and ``b`` has shape
     ``(out_dim,)``. Both are converted with ``.numpy()``.
   """
   def as_scalar_seed(key):
     # Collapse one split key (shape (2,)) into a single int32 scalar;
     # presumably W_init/b_init expect a scalar seed -- TODO confirm.
     return stateless_uniform(shape=[], seed=key, minval=None, maxval=None,
                              dtype=tf.int32)

   out_shape = input_shape[:-1] + (out_dim,)
   key_pair = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2)
   w_seed = as_scalar_seed(key_pair[0])
   b_seed = as_scalar_seed(key_pair[1])
   weights = W_init(seed=w_seed, shape=(input_shape[-1], out_dim))
   bias = b_init(seed=b_seed, shape=(out_dim,))
   return tfnp.zeros(out_shape), (weights.numpy(), bias.numpy())
 def apply_fun(params, inputs, **kwargs):
   """Apply each layer to its own input, independently of the others.

   Args:
     params: per-layer parameters; ``params[i]`` goes to layer ``i``.
     inputs: per-layer inputs; ``inputs[i]`` goes to layer ``i``.
     **kwargs: forwarded to every layer's apply function. An optional
       ``rng`` entry is removed first. When it is not None it is converted
       to an int32 tensor and split into ``nlayers`` per-layer seeds;
       otherwise every layer receives ``rng=None``.

   Returns:
     A list holding one output per layer, in layer order.
   """
   seed = kwargs.pop('rng', None)
   layer_rngs = ((None,) * nlayers if seed is None else
                 split(seed=tf.convert_to_tensor(seed, dtype=tf.int32),
                       num=nlayers))
   return [layer_fn(params[idx], inputs[idx], rng=layer_rngs[idx], **kwargs)
           for idx, layer_fn in enumerate(apply_funs)]
 def init_fun(rng, input_shape):
   """Initialize every layer in sequence, threading shapes through.

   Before each layer, a fresh key pair is split off the running key. One key
   initializes the layer and the other carries on to later layers. The first
   element each layer returns is passed, after ``shape_conversion``, as the
   next layer's ``input_shape``.

   Args:
     rng: seed convertible to an int32 tensor.
     input_shape: input shape of the first layer.

   Returns:
     ``(out, params)``. ``out`` is the first element returned by the last
     layer's init function; if there are no layers it is the original
     ``input_shape``. ``params`` is a list of per-layer parameters.
   """
   carry_key = rng
   shape = input_shape
   layer_params = []
   for layer_init in init_funs:
     key_pair = split(seed=tf.convert_to_tensor(carry_key, dtype=tf.int32),
                      num=2)
     carry_key, layer_key = key_pair[0], key_pair[1]
     shape, param = layer_init(layer_key, shape_conversion(shape))
     layer_params.append(param)
   return shape, layer_params
 def init_fun(rng, input_shape):
   """Initialize a convolution layer's kernel and bias.

   Args:
     rng: seed convertible to an int32 tensor. It is split into a kernel key
       and a bias key.
     input_shape: input shape. It is normalized with ``shape_conversion``
       and indexed according to ``lhs_spec``.

   Returns:
     ``(zeros, (W, b))``. ``zeros`` is ``tfnp.zeros`` of the shape computed
     by ``conv_general_shape_tuple``. ``W`` and ``b`` are the kernel and
     bias exactly as returned by ``W_init`` and ``b_init``.
   """
   input_shape = shape_conversion(input_shape)

   # Build the kernel shape in rhs_spec order: 'O' -> out_chan, 'I' -> the
   # input's channel size, anything else -> the next filter_shape entry.
   spatial_sizes = iter(filter_shape)
   kernel_shape = []
   for dim in rhs_spec:
     if dim == 'O':
       size = out_chan
     elif dim == 'I':
       size = input_shape[lhs_spec.index('C')]
     else:
       size = next(spatial_sizes)
     kernel_shape.append(size)

   output_shape = conv_general_shape_tuple(
       input_shape, kernel_shape, strides, padding, dimension_numbers)

   # out_chan on the output's channel axis, 1 elsewhere, with leading 1s
   # stripped (presumably so the bias broadcasts against the output).
   full_bias_shape = [out_chan if dim == 'C' else 1 for dim in out_spec]
   bias_shape = tuple(
       itertools.dropwhile(lambda size: size == 1, full_bias_shape))

   key_pair = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=2)
   kernel = W_init(seed=key_pair[0], shape=kernel_shape)
   bias = b_init(stddev=1e-6, seed=key_pair[1], shape=bias_shape)
   return tfnp.zeros(output_shape), (kernel, bias)
 def init_fun(rng, input_shape):
   """Initialize each layer independently on its own input shape.

   ``rng`` is converted to an int32 tensor and split into ``nlayers`` keys.
   Layer ``i`` is initialized with key ``i`` and ``input_shape[i]``.

   Args:
     rng: seed convertible to an int32 tensor.
     input_shape: sequence of per-layer input shapes.

   Returns:
     A pair ``(outputs, params)`` of tuples. ``outputs[i]`` and ``params[i]``
     are the two elements returned by layer ``i``'s init function. The result
     is a concrete pair rather than a single-use ``zip`` iterator, so callers
     can index it or iterate it more than once. Callers that unpack it see
     the same values as before.
   """
   rngs = split(seed=tf.convert_to_tensor(rng, dtype=tf.int32), num=nlayers)
   outputs = []
   params = []
   for i in range(nlayers):
     layer_out, layer_params = init_funs[i](rngs[i], input_shape[i])
     outputs.append(layer_out)
     params.append(layer_params)
   return tuple(outputs), tuple(params)