Beispiel #1
0
    def initialize(cls,
                   rng,
                   in_spec,
                   dim_out,
                   kernel_init=stax.glorot(),
                   bias_init=stax.zeros):
        """Creates the kernel and bias parameters of a Dense layer.

        Args:
          rng: Random key; it is split into one key for the kernel and one
            for the bias.
          in_spec: Input spec; only the last entry of its `shape` is read and
            used as the input dimension.
          dim_out: Output dimension.
          kernel_init: Initializer for the `(dim_in, dim_out)` kernel.
          bias_init: Initializer for the `(dim_out,)` bias.

        Returns:
          A `base.LayerParams` wrapping `DenseParams(kernel, bias)`.

        Raises:
          ValueError: If `rng` is None.
        """
        if rng is None:
            raise ValueError('Need valid RNG to instantiate Dense layer.')
        fan_in = in_spec.shape[-1]
        kernel_key, bias_key = random.split(rng)
        kernel = base.create_parameter(kernel_key, (fan_in, dim_out),
                                       init=kernel_init)
        bias = base.create_parameter(bias_key, (dim_out,), init=bias_init)
        return base.LayerParams(DenseParams(kernel, bias))
Beispiel #2
0
 def initialize(cls,
                key,
                in_spec,
                out_chan,
                filter_shape,
                strides=None,
                padding='VALID',
                kernel_init=None,
                bias_init=stax.randn(1e-6),
                use_bias=True):
     """Creates the kernel (and optional bias) parameters of a conv layer.

     Shapes, initializers, strides and padding are all resolved by
     `conv_info`; this method only turns them into parameters.

     Args:
       key: Random key; it is split into one key for the kernel and one for
         the bias.
       in_spec: Input spec; its `shape` is forwarded to `conv_info`.
       out_chan: Number of output channels, forwarded to `conv_info`.
       filter_shape: Filter shape, forwarded to `conv_info`.
       strides: Optional strides, forwarded to `conv_info`.
       padding: Padding mode, forwarded to `conv_info`.
       kernel_init: Optional kernel initializer, forwarded to `conv_info`.
       bias_init: Bias initializer, forwarded to `conv_info`.
       use_bias: If False, no bias parameter is created and the bias slot of
         `ConvParams` is None.

     Returns:
       A `base.LayerParams` holding `ConvParams(kernel, bias_or_None)` and,
       as `info`, a `ConvInfo` built from the strides, padding and third
       value (`one`) returned by `conv_info`, plus `use_bias`.
     """
     conv_shapes, conv_inits, conv_extra = conv_info(
         in_spec.shape,
         out_chan,
         filter_shape,
         strides=strides,
         padding=padding,
         kernel_init=kernel_init,
         bias_init=bias_init)
     resolved_strides, resolved_padding, one = conv_extra
     layer_info = ConvInfo(resolved_strides, resolved_padding, one, use_bias)
     _, kernel_shape, bias_shape = conv_shapes
     kernel_initializer, bias_initializer = conv_inits
     kernel_key, bias_key = random.split(key)
     kernel = base.create_parameter(kernel_key, kernel_shape,
                                    init=kernel_initializer)
     if use_bias:
         bias = base.create_parameter(bias_key, bias_shape,
                                      init=bias_initializer)
     else:
         bias = None
     return base.LayerParams(ConvParams(kernel, bias), info=layer_info)
Beispiel #3
0
 def initialize(cls,
                key,
                in_spec,
                axis=(0, 1),
                momentum=0.99,
                epsilon=1e-5,
                center=True,
                scale=True,
                beta_init=stax.zeros,
                gamma_init=stax.ones):
     """Creates BatchNorm parameters, static info and moving statistics.

     Args:
       key: Random key; it is split into four keys for beta, gamma, the
         moving mean and the moving variance.
       in_spec: Input spec; only its `shape` is read.
       axis: An int or a sequence of ints naming the dimensions of
         `in_spec.shape` that are dropped from the beta/gamma shape and set
         to 1 in the moving-statistics shape. Negative values count from the
         end of `in_spec.shape`.
       momentum: Stored in the info as `decay = 1.0 - momentum`.
       epsilon: Stored unchanged in the info.
       center: If True, a `beta` parameter is created; otherwise `beta` is
         the empty tuple.
       scale: If True, a `gamma` parameter is created; otherwise `gamma` is
         the empty tuple.
       beta_init: Initializer for `beta`.
       gamma_init: Initializer for `gamma`.

     Returns:
       A `base.LayerParams` built from `BatchNormParams(beta, gamma)`,
       `BatchNormInfo(axis, epsilon, center, scale, decay, in_shape)` and
       `BatchNormState(moving_mean, moving_var)`. The moving mean is
       initialized with `stax.zeros` and the moving variance with
       `stax.ones`.
     """
     in_shape = in_spec.shape
     axis = (axis, ) if np.isscalar(axis) else axis
     decay = 1.0 - momentum
     # `enumerate` only yields non-negative indices, so a negative axis would
     # otherwise never match and be silently ignored. Resolve negatives
     # against the input rank for the shape computations only; `info` keeps
     # `axis` exactly as the caller passed it.
     rank = len(in_shape)
     reduced_dims = {a + rank if a < 0 else a for a in axis}
     shape = tuple(d for i, d in enumerate(in_shape) if i not in reduced_dims)
     moving_shape = tuple(1 if i in reduced_dims else d
                          for i, d in enumerate(in_shape))
     k1, k2, k3, k4 = random.split(key, 4)
     beta = base.create_parameter(k1, shape,
                                  init=beta_init) if center else ()
     gamma = base.create_parameter(k2, shape,
                                   init=gamma_init) if scale else ()
     moving_mean = base.create_parameter(k3, moving_shape, init=stax.zeros)
     moving_var = base.create_parameter(k4, moving_shape, init=stax.ones)
     params = BatchNormParams(beta, gamma)
     info = BatchNormInfo(axis, epsilon, center, scale, decay, in_shape)
     state = BatchNormState(moving_mean, moving_var)
     return base.LayerParams(params, info, state)