Ejemplo n.º 1
0
def decoder(hidden_dim, out_dim):
    """Build a two-layer `stax` network ending in a sigmoid.

    The layers are, in order: ``Dense(hidden_dim)`` -> ``Softplus`` ->
    ``Dense(out_dim)`` -> ``Sigmoid``.  Both dense layers initialise their
    weights with ``stax.randn()`` at its default scale.

    Args:
        hidden_dim: Width of the hidden dense layer.
        out_dim: Width of the output dense layer.

    Returns:
        The value produced by ``stax.serial`` for the layer stack (in stax,
        an ``(init_fun, apply_fun)`` pair).
    """
    hidden_layer = stax.Dense(hidden_dim, W_init=stax.randn())
    output_layer = stax.Dense(out_dim, W_init=stax.randn())
    # The final sigmoid squashes every output into the open interval (0, 1).
    layers = [hidden_layer, stax.Softplus, output_layer, stax.Sigmoid]
    return stax.serial(*layers)
Ejemplo n.º 2
0
def encoder(hidden_dim, z_dim):
    """Build a `stax` network with a shared trunk and two output heads.

    The trunk is ``Dense(hidden_dim)`` -> ``Softplus``.  Its output is
    duplicated with ``FanOut(2)`` and fed to two parallel heads:

    * a plain ``Dense(z_dim)`` head, and
    * a ``Dense(z_dim)`` head followed by ``Exp``, so its outputs are
      strictly positive.

    All dense layers initialise their weights with ``stax.randn()`` at its
    default scale.

    Args:
        hidden_dim: Width of the shared hidden dense layer.
        z_dim: Width of each of the two output heads.

    Returns:
        The value produced by ``stax.serial`` for the layer stack (in stax,
        an ``(init_fun, apply_fun)`` pair whose apply function yields the two
        head outputs in the order listed above).
    """
    trunk = (stax.Dense(hidden_dim, W_init=stax.randn()), stax.Softplus)
    linear_head = stax.Dense(z_dim, W_init=stax.randn())
    # Exponentiating the second head keeps its output positive.
    exp_head = stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)
    heads = stax.parallel(linear_head, exp_head)
    return stax.serial(*trunk, stax.FanOut(2), heads)
Ejemplo n.º 3
0
 def initialize(cls,
                key,
                in_spec,
                out_chan,
                filter_shape,
                strides=None,
                padding='VALID',
                kernel_init=None,
                bias_init=stax.randn(1e-6),
                use_bias=True):
     """Create the convolution layer's parameters and static info.

     Shape and initializer details are delegated to `conv_info`, called
     with `in_spec.shape` and the convolution arguments.

     Args:
         key: PRNG key; split in two, one half for the kernel and one for
             the bias.
         in_spec: Input spec; only its `.shape` is read here.
         out_chan: Forwarded to `conv_info`.
         filter_shape: Forwarded to `conv_info`.
         strides: Forwarded to `conv_info`.
         padding: Forwarded to `conv_info`.
         kernel_init: Forwarded to `conv_info`, which returns the kernel
             initializer actually used.
         bias_init: Forwarded to `conv_info`, which returns the bias
             initializer actually used.
         use_bias: When false, no bias parameter is created and the bias
             slot of `ConvParams` is `None`.

     Returns:
         `base.LayerParams` wrapping a `ConvParams(kernel, bias)` with
         `info` set to `ConvInfo(strides, padding, one, use_bias)`.
     """
     conv_kwargs = dict(strides=strides,
                        padding=padding,
                        kernel_init=kernel_init,
                        bias_init=bias_init)
     shapes, inits, (strides, padding, one) = conv_info(
         in_spec.shape, out_chan, filter_shape, **conv_kwargs)
     _, kernel_shape, bias_shape = shapes
     kernel_init, bias_init = inits
     info = ConvInfo(strides, padding, one, use_bias)

     # The key is always split, even when the bias half goes unused.
     kernel_key, bias_key = random.split(key)
     kernel = base.create_parameter(kernel_key, kernel_shape, init=kernel_init)
     bias = None
     if use_bias:
         bias = base.create_parameter(bias_key, bias_shape, init=bias_init)
     return base.LayerParams(ConvParams(kernel, bias), info=info)
Ejemplo n.º 4
0
 def spec(cls,
          in_spec,
          out_chan,
          filter_shape,
          strides=None,
          padding='VALID',
          kernel_init=None,
          bias_init=stax.randn(1e-6),
          use_bias=True):
     """Compute the output spec of the convolution for a given input spec.

     Args:
         in_spec: Input spec; its `.shape` and `.dtype` are read.
         out_chan: Forwarded to `conv_info`.
         filter_shape: Forwarded to `conv_info`.
         strides: Forwarded to `conv_info`.
         padding: Forwarded to `conv_info`.
         kernel_init: Forwarded to `conv_info`.
         bias_init: Forwarded to `conv_info`.
         use_bias: Accepted but ignored; it is deleted immediately.

     Returns:
         `state.Shape` built from the first shape returned by `conv_info`
         and carrying the input's dtype.
     """
     del use_bias  # Not used when computing the output spec.
     conv_kwargs = dict(strides=strides,
                        padding=padding,
                        kernel_init=kernel_init,
                        bias_init=bias_init)
     shapes, _unused_inits, _unused_args = conv_info(
         in_spec.shape, out_chan, filter_shape, **conv_kwargs)
     out_shape = shapes[0]
     return state.Shape(out_shape, dtype=in_spec.dtype)
Ejemplo n.º 5
0
def conv_info(in_shape,
              out_chan,
              filter_shape,
              strides=None,
              padding='VALID',
              kernel_init=None,
              bias_init=stax.randn(1e-6),
              transpose=False):
    """Derive shapes, initializers and conv arguments for an unbatched input.

    The computation follows the `stax` convolution layers.  A leading
    batch dimension of size 1 is prepended to `in_shape` for the shape
    computation and stripped from the resulting output shape.

    Args:
        in_shape: Unbatched input shape; must be a tuple of length 3.
        out_chan: Number of output channels.
        filter_shape: Spatial extent of the kernel, one entry per spatial
            axis of the kernel spec in `DIMENSION_NUMBERS`.
        strides: Convolution strides; falsy values default to all ones.
        padding: Padding argument passed to the `lax` shape helpers.
        kernel_init: Kernel initializer; falsy values default to
            `stax.glorot` configured with the kernel spec's 'O'/'I' axes.
        bias_init: Bias initializer, returned unchanged.
        transpose: If true, compute the transposed-convolution output shape.

    Returns:
        A triple `(shapes, inits, (strides, padding, one))` where `shapes`
        is `(out_shape, kernel_shape, bias_shape)`, `inits` is
        `(kernel_init, bias_init)` and `one` is a tuple of ones with the
        same length as `filter_shape`.

    Raises:
        ValueError: If `in_shape` does not have exactly three entries.
    """
    # Essentially the `stax` implementation
    if len(in_shape) != 3:
        raise ValueError('Need to `jax.vmap` in order to batch')

    batched_shape = (1, ) + in_shape
    lhs_spec, rhs_spec, out_spec = DIMENSION_NUMBERS
    one = (1, ) * len(filter_shape)
    if not strides:
        strides = one
    if not kernel_init:
        kernel_init = stax.glorot(rhs_spec.index('O'), rhs_spec.index('I'))

    # Walk the kernel spec: 'O'/'I' take channel counts, every other axis
    # consumes the next spatial size from `filter_shape`.
    spatial_sizes = iter(filter_shape)
    kernel_dims = []
    for axis in rhs_spec:
        if axis == 'O':
            kernel_dims.append(out_chan)
        elif axis == 'I':
            kernel_dims.append(batched_shape[lhs_spec.index('C')])
        else:
            kernel_dims.append(next(spatial_sizes))
    kernel_shape = tuple(kernel_dims)

    shape_fn = (lax.conv_transpose_shape_tuple
                if transpose else lax.conv_general_shape_tuple)
    batched_out_shape = shape_fn(batched_shape, kernel_shape, strides,
                                 padding, DIMENSION_NUMBERS)

    # The bias has `out_chan` on the channel axis and 1 elsewhere; leading
    # ones are dropped so it broadcasts against the output.
    bias_dims = (out_chan if axis == 'C' else 1 for axis in out_spec)
    bias_shape = tuple(itertools.dropwhile(lambda dim: dim == 1, bias_dims))

    # Strip the artificial batch dimension added above.
    out_shape = batched_out_shape[1:]
    return (
        (out_shape, kernel_shape, bias_shape),
        (kernel_init, bias_init),
        (strides, padding, one),
    )
Ejemplo n.º 6
0
 def testRandnInitShape(self, shape):
     key = random.PRNGKey(0)
     out = stax.randn()(key, shape)
     self.assertEqual(out.shape, shape)