def decoder(hidden_dim, out_dim):
    """Build the decoder network as a `stax.serial` stack.

    The stack is, in order: a dense layer of width `hidden_dim`, a softplus
    non-linearity, a dense layer of width `out_dim`, and a final sigmoid.
    Both dense layers draw their weights from `stax.randn()`.

    Args:
        hidden_dim: Output width of the first dense layer.
        out_dim: Output width of the second dense layer.

    Returns:
        The layer produced by `stax.serial` for the four stages above.
    """
    hidden_layer = stax.Dense(hidden_dim, W_init=stax.randn())
    output_layer = stax.Dense(out_dim, W_init=stax.randn())
    stages = [hidden_layer, stax.Softplus, output_layer, stax.Sigmoid]
    return stax.serial(*stages)
def encoder(hidden_dim, z_dim):
    """Build the two-headed encoder network from `stax` combinators.

    A dense layer of width `hidden_dim` followed by softplus forms a shared
    trunk. Its output is duplicated with `stax.FanOut(2)` and fed to two
    parallel heads, each a dense layer of width `z_dim`; the second head is
    additionally passed through `stax.Exp`. Every dense layer draws its
    weights from `stax.randn()`.

    Args:
        hidden_dim: Output width of the shared dense layer.
        z_dim: Output width of each of the two heads.

    Returns:
        The layer produced by `stax.serial` for the trunk, fan-out and heads.
    """
    trunk = stax.Dense(hidden_dim, W_init=stax.randn())
    plain_head = stax.Dense(z_dim, W_init=stax.randn())
    # The second head ends in Exp, so its outputs are strictly positive.
    exp_head = stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)
    heads = stax.parallel(plain_head, exp_head)
    return stax.serial(trunk, stax.Softplus, stax.FanOut(2), heads)
def initialize(cls, key, in_spec, out_chan, filter_shape, strides=None,
               padding='VALID', kernel_init=None, bias_init=stax.randn(1e-6),
               use_bias=True):
    """Create the convolution layer's parameters and static info.

    Shapes and initializers are obtained from `conv_info` using
    `in_spec.shape`. The kernel parameter is always created; the bias
    parameter is created only when `use_bias` is true, otherwise the bias
    slot of `ConvParams` is `None`.

    Args:
        key: Random key; split in two, one half per parameter.
        in_spec: Input spec; only its `shape` attribute is read here.
        out_chan: Number of output channels, forwarded to `conv_info`.
        filter_shape: Filter shape, forwarded to `conv_info`.
        strides: Optional strides, forwarded to `conv_info`.
        padding: Padding argument, forwarded to `conv_info`.
        kernel_init: Optional kernel initializer, forwarded to `conv_info`.
        bias_init: Bias initializer, forwarded to `conv_info`.
        use_bias: Whether a bias parameter is created.

    Returns:
        `base.LayerParams` wrapping the `ConvParams`, with a `ConvInfo`
        (strides, padding, `one`, `use_bias`) attached as `info`.
    """
    shapes, inits, (strides, padding, one) = conv_info(
        in_spec.shape, out_chan, filter_shape,
        strides=strides, padding=padding,
        kernel_init=kernel_init, bias_init=bias_init)
    info = ConvInfo(strides, padding, one, use_bias)

    _, kernel_shape, bias_shape = shapes
    kernel_init, bias_init = inits

    # The key is always split, even when the bias half ends up unused.
    kernel_key, bias_key = random.split(key)
    kernel = base.create_parameter(kernel_key, kernel_shape, init=kernel_init)
    bias = (base.create_parameter(bias_key, bias_shape, init=bias_init)
            if use_bias else None)
    return base.LayerParams(ConvParams(kernel, bias), info=info)
def spec(cls, in_spec, out_chan, filter_shape, strides=None, padding='VALID',
         kernel_init=None, bias_init=stax.randn(1e-6), use_bias=True):
    """Compute the output spec of the convolution for a given input spec.

    Args:
        in_spec: Input spec; its `shape` and `dtype` attributes are read.
        out_chan: Number of output channels, forwarded to `conv_info`.
        filter_shape: Filter shape, forwarded to `conv_info`.
        strides: Optional strides, forwarded to `conv_info`.
        padding: Padding argument, forwarded to `conv_info`.
        kernel_init: Optional kernel initializer, forwarded to `conv_info`.
        bias_init: Bias initializer, forwarded to `conv_info`.
        use_bias: Accepted for signature parity and ignored.

    Returns:
        A `state.Shape` built from the first shape reported by `conv_info`
        and carrying the input's dtype.
    """
    del use_bias  # Has no influence on the output shape.
    shapes, _inits, _extras = conv_info(
        in_spec.shape, out_chan, filter_shape,
        strides=strides, padding=padding,
        kernel_init=kernel_init, bias_init=bias_init)
    out_shape = shapes[0]
    return state.Shape(out_shape, dtype=in_spec.dtype)
def conv_info(in_shape, out_chan, filter_shape, strides=None, padding='VALID',
              kernel_init=None, bias_init=stax.randn(1e-6), transpose=False):
    """Derive shapes and initializers for a convolution on an unbatched input.

    Follows the `stax` convolution setup: a leading batch axis of size 1 is
    prepended to `in_shape` for the shape computation and stripped from the
    resulting output shape.

    Args:
        in_shape: Unbatched input shape; must be a tuple of length 3.
        out_chan: Number of output channels.
        filter_shape: Spatial filter sizes, consumed in order for every
            kernel axis that is neither 'O' nor 'I'.
        strides: Optional strides; a falsy value means all ones.
        padding: Padding argument passed to the `lax` shape helper.
        kernel_init: Optional kernel initializer; a falsy value selects
            `stax.glorot` on the kernel's 'O' and 'I' axis positions.
        bias_init: Bias initializer, returned unchanged.
        transpose: Use the transposed-convolution shape helper if true.

    Returns:
        A triple `(shapes, inits, (strides, padding, one))` where `shapes` is
        `(out_shape, kernel_shape, bias_shape)`, `inits` is
        `(kernel_init, bias_init)` and `one` is a tuple of ones with one
        entry per filter dimension.

    Raises:
        ValueError: If `in_shape` does not have exactly three entries.
    """
    if len(in_shape) != 3:
        raise ValueError('Need to `jax.vmap` in order to batch')

    batched_shape = (1, ) + in_shape
    lhs_spec, rhs_spec, out_spec = DIMENSION_NUMBERS

    one = (1, ) * len(filter_shape)
    if not strides:
        strides = one
    if not kernel_init:
        kernel_init = stax.glorot(rhs_spec.index('O'), rhs_spec.index('I'))

    # Walk the kernel layout: 'O' and 'I' take the channel counts, every
    # other axis takes the next spatial filter size.
    spatial_sizes = iter(filter_shape)
    kernel_dims = []
    for axis in rhs_spec:
        if axis == 'O':
            kernel_dims.append(out_chan)
        elif axis == 'I':
            kernel_dims.append(batched_shape[lhs_spec.index('C')])
        else:
            kernel_dims.append(next(spatial_sizes))
    kernel_shape = tuple(kernel_dims)

    shape_fn = (lax.conv_transpose_shape_tuple if transpose
                else lax.conv_general_shape_tuple)
    batched_out = shape_fn(batched_shape, kernel_shape, strides, padding,
                           DIMENSION_NUMBERS)

    # Size 1 everywhere except the channel axis, then drop the leading ones.
    def _is_one(dim):
        return dim == 1

    padded_bias = [out_chan if axis == 'C' else 1 for axis in out_spec]
    bias_shape = tuple(itertools.dropwhile(_is_one, padded_bias))

    shapes = (batched_out[1:], kernel_shape, bias_shape)
    inits = (kernel_init, bias_init)
    return shapes, inits, (strides, padding, one)
def testRandnInitShape(self, shape): key = random.PRNGKey(0) out = stax.randn()(key, shape) self.assertEqual(out.shape, shape)