def initialize(cls, rng, in_spec, dim_out, kernel_init=stax.glorot(),
               bias_init=stax.zeros):
    """Creates the kernel and bias parameters of a dense layer.

    Args:
      cls: The layer class this initializer is attached to (not read here).
      rng: Random key. It is split in two: one half seeds the kernel, the
        other seeds the bias.
      in_spec: Input spec. Only the last entry of `in_spec.shape` is read; it
        is used as the input dimension of the kernel.
      dim_out: Output dimension.
      kernel_init: Initializer handed to `base.create_parameter` for the
        `(dim_in, dim_out)` kernel.
      bias_init: Initializer handed to `base.create_parameter` for the
        `(dim_out,)` bias.

    Returns:
      A `base.LayerParams` wrapping a `DenseParams` built from the kernel and
      the bias, in that order.

    Raises:
      ValueError: If `rng` is None.
    """
    if rng is None:
        raise ValueError('Need valid RNG to instantiate Dense layer.')

    dim_in = in_spec.shape[-1]
    kernel_key, bias_key = random.split(rng)

    kernel_shape = (dim_in, dim_out)
    bias_shape = (dim_out,)
    kernel = base.create_parameter(kernel_key, kernel_shape, init=kernel_init)
    bias = base.create_parameter(bias_key, bias_shape, init=bias_init)

    return base.LayerParams(DenseParams(kernel, bias))
def initialize(cls, key, in_spec, out_chan, filter_shape, strides=None,
               padding='VALID', kernel_init=None, bias_init=stax.randn(1e-6),
               use_bias=True):
    """Creates the kernel (and optionally bias) parameters of a conv layer.

    Args:
      cls: The layer class this initializer is attached to (not read here).
      key: Random key. It is always split in two; the second half is only
        consumed when `use_bias` is truthy.
      in_spec: Input spec; its `shape` is forwarded to `conv_info`.
      out_chan: Number of output channels, forwarded to `conv_info`.
      filter_shape: Filter shape, forwarded to `conv_info`.
      strides: Strides, forwarded to `conv_info`; the value stored in the
        layer info is the one `conv_info` returns.
      padding: Padding, forwarded to `conv_info`; the value stored in the
        layer info is the one `conv_info` returns.
      kernel_init: Kernel initializer forwarded to `conv_info`; the
        initializer actually used is the one `conv_info` returns.
      bias_init: Bias initializer forwarded to `conv_info`; the initializer
        actually used is the one `conv_info` returns.
      use_bias: When falsy, no bias parameter is created and the bias slot of
        `ConvParams` is None.

    Returns:
      A `base.LayerParams` holding a `ConvParams` (kernel, bias-or-None) and,
      as `info`, a `ConvInfo` of (strides, padding, one, use_bias).
    """
    # conv_info resolves the shapes, the initializers and the stride/padding
    # configuration in a single call. The third configuration element (`one`)
    # is passed through to ConvInfo untouched.
    shapes, inits, (strides, padding, one) = conv_info(
        in_spec.shape, out_chan, filter_shape,
        strides=strides, padding=padding,
        kernel_init=kernel_init, bias_init=bias_init)
    layer_info = ConvInfo(strides, padding, one, use_bias)

    # The first entry of `shapes` is not needed here.
    _, kernel_shape, bias_shape = shapes
    kernel_init, bias_init = inits

    kernel_key, bias_key = random.split(key)
    kernel = base.create_parameter(kernel_key, kernel_shape, init=kernel_init)
    bias = (base.create_parameter(bias_key, bias_shape, init=bias_init)
            if use_bias else None)

    return base.LayerParams(ConvParams(kernel, bias), info=layer_info)
def initialize(cls, key, in_spec, axis=(0, 1), momentum=0.99, epsilon=1e-5,
               center=True, scale=True, beta_init=stax.zeros,
               gamma_init=stax.ones):
    """Creates the parameters, info and running statistics of a batch norm.

    Args:
      cls: The layer class this initializer is attached to (not read here).
      key: Random key. It is split four ways (beta, gamma, moving mean,
        moving variance) regardless of `center` and `scale`.
      in_spec: Input spec; only its `shape` is read.
      axis: Axis index or collection of axis indices. A scalar is wrapped
        into a one-element tuple.
      momentum: Stored in the layer info as `decay = 1.0 - momentum`.
      epsilon: Stored unchanged in the layer info.
      center: When falsy, no beta parameter is created and beta is `()`.
      scale: When falsy, no gamma parameter is created and gamma is `()`.
      beta_init: Initializer for beta.
      gamma_init: Initializer for gamma.

    Returns:
      A `base.LayerParams` built from `BatchNormParams(beta, gamma)`,
      `BatchNormInfo(axis, epsilon, center, scale, decay, in_shape)` and
      `BatchNormState(moving_mean, moving_var)`.
    """
    in_shape = in_spec.shape
    if np.isscalar(axis):
        axis = (axis,)
    decay = 1.0 - momentum

    # beta/gamma keep only the dimensions that are NOT listed in `axis`; the
    # running statistics keep the full rank, with listed dimensions set to 1.
    # NOTE(review): membership is tested against the non-negative indices from
    # enumerate(), so negative axis values never match - confirm callers only
    # pass non-negative axes.
    param_shape = tuple(
        dim for idx, dim in enumerate(in_shape) if idx not in axis)
    stats_shape = tuple(
        1 if idx in axis else dim for idx, dim in enumerate(in_shape))

    beta_key, gamma_key, mean_key, var_key = random.split(key, 4)

    if center:
        beta = base.create_parameter(beta_key, param_shape, init=beta_init)
    else:
        beta = ()
    if scale:
        gamma = base.create_parameter(gamma_key, param_shape, init=gamma_init)
    else:
        gamma = ()

    # Running mean starts at zeros and running variance at ones.
    moving_mean = base.create_parameter(mean_key, stats_shape, init=stax.zeros)
    moving_var = base.create_parameter(var_key, stats_shape, init=stax.ones)

    return base.LayerParams(
        BatchNormParams(beta, gamma),
        BatchNormInfo(axis, epsilon, center, scale, decay, in_shape),
        BatchNormState(moving_mean, moving_var))