Example #1
0
def conv_info(in_shape,
              out_chan,
              filter_shape,
              strides=None,
              padding='VALID',
              kernel_init=None,
              bias_init=stax.randn(1e-6),
              transpose=False):
    """Returns parameters and output shape information given input shapes.

    Essentially the `stax` convolution shape logic, applied to a single
    unbatched example. A leading batch axis of size 1 is added internally
    and stripped from the output shape again, so batching is expected to be
    done with `jax.vmap`.

    Args:
      in_shape: Unbatched input shape (any sequence). It must have exactly
        one axis fewer than the input spec `DIMENSION_NUMBERS[0]`. The
        missing axis is assumed to be the leading batch axis.
      out_chan: Number of output channels.
      filter_shape: Kernel extents for every kernel-spec axis other than
        'O' and 'I', in kernel-spec order.
      strides: Convolution strides; defaults to all ones.
      padding: Padding forwarded to `lax` ('VALID' by default).
      kernel_init: Kernel initializer; defaults to `stax.glorot` over the
        kernel's 'O'/'I' axes.
      bias_init: Bias initializer.
      transpose: If True, compute the output shape with
        `lax.conv_transpose_shape_tuple` instead of
        `lax.conv_general_shape_tuple`.

    Returns:
      `(shapes, inits, (strides, padding, one))`, where:
        - `shapes` is `(out_shape, kernel_shape, bias_shape)`, with
          `out_shape` lacking the batch axis;
        - `inits` is `(kernel_init, bias_init)`;
        - `one` is a tuple of ones with one entry per filter dimension.

    Raises:
      ValueError: If `in_shape` has the wrong number of axes (e.g. it
        already carries a batch axis), or if `filter_shape` does not have
        one entry per non-'O'/'I' axis of the kernel spec.
    """
    # Essentially the `stax` implementation
    lhs_spec, rhs_spec, out_spec = DIMENSION_NUMBERS
    in_shape = tuple(in_shape)  # also accept lists
    filter_shape = tuple(filter_shape)

    expected_rank = len(lhs_spec) - 1  # input spec minus the batch axis
    if len(in_shape) != expected_rank:
        raise ValueError(
            'Expected an unbatched input shape with {} axes, got {}. '
            'Need to `jax.vmap` in order to batch'.format(
                expected_rank, in_shape))

    # Every kernel axis that is not 'O' or 'I' consumes one filter entry.
    n_filter_axes = sum(1 for c in rhs_spec if c not in ('O', 'I'))
    if len(filter_shape) != n_filter_axes:
        raise ValueError(
            'Expected {} filter dimensions for kernel spec {!r}, got '
            'filter_shape={}'.format(n_filter_axes, rhs_spec, filter_shape))

    batched_shape = (1, ) + in_shape
    one = (1, ) * len(filter_shape)
    strides = strides or one
    # NOTE(review): positional order here is (O index, I index); confirm it
    # matches the installed `stax.glorot` signature.
    kernel_init = kernel_init or stax.glorot(rhs_spec.index('O'),
                                             rhs_spec.index('I'))

    in_chan = batched_shape[lhs_spec.index('C')]
    filter_dims = iter(filter_shape)
    # Length was validated above, so `next` cannot run out here.
    kernel_shape = tuple([
        out_chan if c == 'O' else in_chan if c == 'I' else next(filter_dims)
        for c in rhs_spec
    ])

    if transpose:
        out_shape = lax.conv_transpose_shape_tuple(batched_shape,
                                                   kernel_shape, strides,
                                                   padding, DIMENSION_NUMBERS)
    else:
        out_shape = lax.conv_general_shape_tuple(batched_shape, kernel_shape,
                                                 strides, padding,
                                                 DIMENSION_NUMBERS)

    # One bias entry per output channel. Leading size-1 axes are dropped
    # (trailing-aligned broadcasting does not need them). Note that if
    # out_chan == 1 this yields the scalar shape ().
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))

    out_shape = out_shape[1:]  # strip the internal batch axis
    shapes = (out_shape, kernel_shape, bias_shape)
    inits = (kernel_init, bias_init)
    return shapes, inits, (strides, padding, one)
Example #2
0
  def initialize(cls, rng, in_spec, dim_out,
                 kernel_init=stax.glorot(),
                 bias_init=stax.zeros):
    """Initializes Dense Layer.

    Builds a kernel of shape `(dim_in, dim_out)` and a bias of shape
    `(dim_out,)`, where `dim_in` is the last entry of `in_spec.shape`.
    `rng` is split into two subkeys, one used for the kernel and one for
    the bias.

    Args:
      rng: Random key; must not be None.
      in_spec: Input Spec. Only `in_spec.shape[-1]` is read.
      dim_out: Number of output features (size of the kernel's last axis
        and of the bias).
      kernel_init: Kernel initialization function, forwarded to
        `base.create_parameter`.
      bias_init: Bias initialization function, forwarded to
        `base.create_parameter`.

    Returns:
      A `base.LayerParams` wrapping `DenseParams(kernel, bias)`. No output
      shape is returned.

    Raises:
      ValueError: If `rng` is None.
    """
    if rng is None:
      raise ValueError('Need valid RNG to instantiate Dense layer.')
    dim_in = in_spec.shape[-1]
    # Separate subkeys so kernel and bias are drawn from distinct keys.
    kernel_key, bias_key = random.split(rng)
    params = DenseParams(
        base.create_parameter(kernel_key, (dim_in, dim_out),
                              init=kernel_init),
        base.create_parameter(bias_key, (dim_out,), init=bias_init)
    )
    return base.LayerParams(params)
Example #3
0
 def testGlorotInitShape(self, shape):
     key = random.PRNGKey(0)
     out = stax.glorot()(key, shape)
     self.assertEqual(out.shape, shape)
Example #4
0
def create_parameter(rng, spec, init=stax.glorot()):
    """Materializes a parameter by calling `init(rng, spec)`.

    Args:
      rng: Random key forwarded to the initializer.
      spec: Forwarded unchanged as the initializer's second argument
        (e.g. a shape tuple).
      init: Initializer callable. The default `stax.glorot()` instance is
        built once, when the function is defined.

    Returns:
      Whatever `init` returns.
    """
    parameter = init(rng, spec)
    return parameter