Example No. 1
def GeneralConv(dimension_numbers,
                out_chan,
                filter_shape,
                strides=None,
                padding='VALID',
                W_gain=1.0,
                W_init=stax.randn(1.0),
                b_gain=0.0,
                b_init=stax.randn(1.0)):
    """Layer construction function for a general convolution layer.

  Uses jax.experimental.stax.GeneralConv as a base.
  """
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1, ) * len(filter_shape)
    strides = strides or one
    init_fun, _ = stax.GeneralConv(dimension_numbers, out_chan, filter_shape,
                                   strides, padding, W_init, b_init)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        norm = inputs.shape[lhs_spec.index('C')]
        norm *= functools.reduce(op.mul, filter_shape)
        norm = W_gain / np.sqrt(norm)
        return norm * lax.conv_general_dilated(inputs, W, strides, padding,
                                               one, one,
                                               dimension_numbers) + b_gain * b

    return init_fun, apply_fun
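A minimal usage sketch for the constructor above; the ('NHWC', 'HWIO', 'NHWC') dimension numbers, channel count, and toy input shape are assumptions chosen for illustration, and the imports listed are the ones the snippet itself relies on (newer JAX exposes stax as jax.example_libraries.stax; older releases used jax.experimental.stax):

import functools
import operator as op

import jax.numpy as np
from jax import lax, random
from jax.example_libraries import stax

# 3x3 convolution with 16 output channels over NHWC inputs.
init_fun, apply_fun = GeneralConv(('NHWC', 'HWIO', 'NHWC'), 16, (3, 3))
out_shape, params = init_fun(random.PRNGKey(0), (1, 28, 28, 3))
y = apply_fun(params, np.zeros((1, 28, 28, 3)))  # shape (1, 26, 26, 16)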
Example No. 2
def decoder(hidden_dim: int, out_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
Example No. 3
def decoder(hidden_dim, out_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
Example No. 4
def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()), stax.Softplus,
        stax.FanOut(2),
        stax.parallel(stax.Dense(z_dim, W_init=stax.randn()),
                      stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )
Example No. 5
 def __init__(self, breakpoints):
     self.N_OUTPUTS = len(breakpoints) + 1
     self.breakpoints = np.hstack(([0], breakpoints, [np.inf]))
     self.terminal_layer = [
         stax.Dense(self.N_OUTPUTS,
                    W_init=stax.randn(1e-7),
                    b_init=stax.randn(1e-7)),
         stax.Exp,
     ]
Example No. 6
def encoder(hidden_dim: int, z_dim: int) -> Tuple[Callable, Callable]:
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(
                stax.Dense(z_dim, W_init=stax.randn()),
                stax.Exp,
            ),
        ),
    )
Example No. 7
def MaskedDense(mask, bias=True, W_init=glorot(), b_init=randn()):
    """
    As in jax.experimental.stax, each layer constructor function returns
    an (init_fun, apply_fun) pair, where `init_fun` takes an rng key and
    an input shape and returns an (output_shape, params) pair, and
    `apply_fun` takes params, inputs, and an rng key and applies the layer.

    :param array mask: Mask of shape (input_dim, out_dim) applied to the weights of the layer.
    :param bool bias: whether to include a bias term.
    :param callable W_init: initialization method for the weights.
    :param callable b_init: initialization method for the bias terms.
    :return: an (`init_fun`, `apply_fun`) pair.
    """
    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        W = W_init(k1, mask.shape)
        if bias:
            b = b_init(k2, mask.shape[-1:])
            params = (W, b)
        else:
            params = W
        return input_shape[:-1] + mask.shape[-1:], params

    def apply_fun(params, inputs, **kwargs):
        if bias:
            W, b = params
            return np.dot(inputs, W * mask) + b
        else:
            W = params
            return np.dot(inputs, W * mask)

    return init_fun, apply_fun
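A minimal usage sketch; the lower-triangular mask, batch shape, and PRNG key are assumptions for illustration, and glorot and randn are assumed to have been imported from stax in the module where MaskedDense is defined:

import jax.numpy as np
from jax import random

# A 4x4 lower-triangular mask yields an autoregressive-style dense layer.
mask = np.tril(np.ones((4, 4)))
init_fun, apply_fun = MaskedDense(mask)
out_shape, params = init_fun(random.PRNGKey(0), (2, 4))  # out_shape == (2, 4)
y = apply_fun(params, np.ones((2, 4)))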
Example No. 8
    def new(input_size, output_size, hidden_layers, key):

        _randn_fn = randn()

        def vector_init(shape):
            if isinstance(shape, int):
                shape = (shape, )
            nonlocal key
            key, rng = random.split(key)
            return _randn_fn(rng, shape)

        _glorot_fn = glorot()

        def matrix_init(shape):
            nonlocal key
            key, rng = random.split(key)
            return _glorot_fn(rng, shape)

        input_state = vector_init(input_size)
        hidden_states = []
        for size in hidden_layers:
            hidden_states.append(vector_init(size))
        output_states = vector_init(output_size)
        states = [input_state, *hidden_states, output_states]

        # weights
        fwd_weights, bwd_weights = [], []
        for prev, post in zip(states[:-1], states[1:]):
            fwd_weights.append(matrix_init((prev.shape[0], post.shape[0])))
            bwd_weights.append(matrix_init((post.shape[0], prev.shape[0])))

        return LayeredNet(states, [*fwd_weights, *bwd_weights])
Example No. 9
 def initialize(cls,
                key,
                in_spec,
                out_chan,
                filter_shape,
                strides=None,
                padding='VALID',
                kernel_init=None,
                bias_init=stax.randn(1e-6),
                use_bias=True):
     in_shape = in_spec.shape
     shapes, inits, (strides, padding,
                     one) = conv_info(in_shape,
                                      out_chan,
                                      filter_shape,
                                      strides=strides,
                                      padding=padding,
                                      kernel_init=kernel_init,
                                      bias_init=bias_init)
     info = ConvInfo(strides, padding, one, use_bias)
     _, kernel_shape, bias_shape = shapes
     kernel_init, bias_init = inits
     k1, k2 = random.split(key)
     if use_bias:
         params = ConvParams(
             base.create_parameter(k1, kernel_shape, init=kernel_init),
             base.create_parameter(k2, bias_shape, init=bias_init),
         )
     else:
         params = ConvParams(
             base.create_parameter(k1, kernel_shape, init=kernel_init),
             None)
     return base.LayerParams(params, info=info)
Example No. 10
def Dense(out_dim,
          W_gain=1.0,
          W_init=stax.randn(1.0),
          b_gain=0.0,
          b_init=stax.randn(1.0)):
    """Layer constructor function for a dense (fully-connected) layer.

  Uses jax.experimental.stax.Dense as a base.
  """
    init_fun, _ = stax.Dense(out_dim, W_init, b_init)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        norm = W_gain / np.sqrt(inputs.shape[-1])
        return norm * np.dot(inputs, W) + b_gain * b

    return init_fun, apply_fun
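A minimal sketch of calling the normalized Dense layer above; the layer width, PRNG key, and toy batch are assumptions (stax here is jax.example_libraries.stax, exposed as jax.experimental.stax in older releases):

import jax.numpy as np
from jax import random
from jax.example_libraries import stax

init_fun, apply_fun = Dense(32)
out_shape, params = init_fun(random.PRNGKey(0), (8, 64))  # out_shape == (8, 32)
y = apply_fun(params, np.ones((8, 64)))                   # scaled by W_gain / sqrt(64)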
Example No. 11
File: vae.py  Project: byzhang/d3p
def decoder(hidden_dim, out_dim):
    """Defines the decoder, i.e., the network taking us from latent
        variables back to observations (or at least observation space).

    Network structure:
    z -> dense layer of hidden_dim with softplus activation -> dense layer of out_dim with sigmoid activation

    :param hidden_dim: number of nodes in the hidden layer
    :param out_dim: dimensions of the observations

    :return: (init_fun, apply_fun) pair of the decoder: (decoder_init, decode)
    """
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
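A minimal sketch of driving the decoder above; the hidden/output dimensions, PRNG key, and latent batch are assumptions chosen for illustration:

import jax.numpy as np
from jax import random
from jax.example_libraries import stax  # older releases: jax.experimental.stax

decoder_init, decode = decoder(hidden_dim=400, out_dim=784)
out_shape, params = decoder_init(random.PRNGKey(0), (16, 50))  # 16 latents of dim 50
x_probs = decode(params, np.zeros((16, 50)))  # values in (0, 1) thanks to the sigmoid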
Example No. 12
def Dense(name, out_dim, W_init=stax.glorot(), b_init=stax.randn()):
    """Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, example_input):
        input_shape = example_input.shape
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (out_dim, input_shape[-1])), b_init(k2, (out_dim, ))
        return W, b

    def apply_fun(params, inputs):
        W, b = params
        return np.dot(W, inputs) + b

    return core.Layer(name, init_fun, apply_fun).bind
Example No. 13
File: vae.py  Project: byzhang/d3p
def encoder(hidden_dim, z_dim):
    """Defines the encoder, i.e., the network taking us from observations
        to (a distribution of) latent variables.

    z follows a normal distribution and therefore needs a mean and a variance.

    Network structure:
    x -> dense layer of hidden_dim with softplus activation --> dense layer of z_dim ( = means/loc of z)
                                                            |-> dense layer of z_dim with (elementwise) exp() as activation func ( = variance of z )
    (note: the exp() activation serves solely to ensure positivity of the variance)

    :param hidden_dim: number of nodes in the hidden layer
    :param z_dim: dimension of the latent variable z
    :return: (init_fun, apply_fun) pair of the encoder: (encoder_init, encode)
    """
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )
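A minimal sketch of the encoder above, showing the two heads produced by FanOut/parallel; the dimensions and PRNG key are assumptions:

import jax.numpy as np
from jax import random
from jax.example_libraries import stax  # older releases: jax.experimental.stax

encoder_init, encode = encoder(hidden_dim=400, z_dim=50)
out_shape, params = encoder_init(random.PRNGKey(0), (16, 784))
z_loc, z_var = encode(params, np.zeros((16, 784)))  # each has shape (16, 50)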
Example No. 14
 def spec(cls,
          in_spec,
          out_chan,
          filter_shape,
          strides=None,
          padding='VALID',
          kernel_init=None,
          bias_init=stax.randn(1e-6),
          use_bias=True):
     del use_bias
     in_shape = in_spec.shape
     shapes, _, _ = conv_info(in_shape,
                              out_chan,
                              filter_shape,
                              strides=strides,
                              padding=padding,
                              kernel_init=kernel_init,
                              bias_init=bias_init)
     return state.Shape(shapes[0], dtype=in_spec.dtype)
Example No. 15
def conv_info(in_shape,
              out_chan,
              filter_shape,
              strides=None,
              padding='VALID',
              kernel_init=None,
              bias_init=stax.randn(1e-6),
              transpose=False):
    """Returns parameters and output shape information given input shapes."""
    # Essentially the `stax` implementation
    if len(in_shape) != 3:
        raise ValueError('Need to `jax.vmap` in order to batch')
    in_shape = (1, ) + in_shape
    lhs_spec, rhs_spec, out_spec = DIMENSION_NUMBERS
    one = (1, ) * len(filter_shape)
    strides = strides or one
    kernel_init = kernel_init or stax.glorot(rhs_spec.index('O'),
                                             rhs_spec.index('I'))
    filter_shape_iter = iter(filter_shape)
    kernel_shape = tuple([
        out_chan if c == 'O' else
        in_shape[lhs_spec.index('C')] if c == 'I' else next(filter_shape_iter)
        for c in rhs_spec
    ])
    if transpose:
        out_shape = lax.conv_transpose_shape_tuple(in_shape, kernel_shape,
                                                   strides, padding,
                                                   DIMENSION_NUMBERS)
    else:
        out_shape = lax.conv_general_shape_tuple(in_shape, kernel_shape,
                                                 strides, padding,
                                                 DIMENSION_NUMBERS)
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
    out_shape = out_shape[1:]
    shapes = (out_shape, kernel_shape, bias_shape)
    inits = (kernel_init, bias_init)
    return shapes, inits, (strides, padding, one)
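The shape bookkeeping above mirrors what stax.GeneralConv's init_fun reports; a quick cross-check sketch, assuming DIMENSION_NUMBERS = ('NHWC', 'HWIO', 'NHWC') and a toy 28x28x3 input:

from jax import random
from jax.example_libraries import stax  # older releases: jax.experimental.stax

DIMENSION_NUMBERS = ('NHWC', 'HWIO', 'NHWC')
init_fun, _ = stax.GeneralConv(DIMENSION_NUMBERS, 16, (3, 3))
out_shape, (W, b) = init_fun(random.PRNGKey(0), (1, 28, 28, 3))
# out_shape == (1, 26, 26, 16); conv_info((28, 28, 3), 16, (3, 3)) reports the
# same output shape minus the leading batch dimension it adds internally.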
Example No. 16
def MaskedDense(out_dim, mask, W_init=glorot(), b_init=randn()):
    """
    As in jax.experimental.stax, each layer constructor function returns
    an (init_fun, apply_fun) pair, where `init_fun` takes an rng key and
    an input shape and returns an (output_shape, params) pair, and
    `apply_fun` takes params, inputs, and an rng key and applies the layer.

    :param int out_dim: Number of output dimensions.
    :param array mask: Mask applied to the weights of the layer.
    :param callable W_init: initialization method for the weights.
    :param callable b_init: initialization method for the bias terms.
    :return: an (`init_fun`, `apply_fun`) pair.
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim, ))
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return np.dot(inputs, W * mask) + b

    return init_fun, apply_fun
Example No. 17
 def __init__(self, out_dim, W_init=stax.glorot(), b_init=stax.randn()):
     super(Dense, self).__init__()
     self.out_dim = out_dim
     self.W_init = W_init
     self.b_init = b_init
Example No. 18
 def testRandnInitShape(self, shape):
     key = random.PRNGKey(0)
     out = stax.randn()(key, shape)
     self.assertEqual(out.shape, shape)
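For reference, stax.randn(stddev) returns an initializer init_fun(key, shape) that draws i.i.d. normal values scaled by stddev; a minimal sketch (key, shape, and scale are arbitrary):

from jax import random
from jax.example_libraries import stax  # older releases: jax.experimental.stax

init = stax.randn(1e-2)
W = init(random.PRNGKey(0), (3, 4))
assert W.shape == (3, 4)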
Example No. 19
 def testRandnInitShape(self, shape):
     # Calls the initializer with only a shape, matching an older stax API
     # in which init functions did not take a PRNG key (compare Example No. 18).
     out = stax.randn()(shape)
     self.assertEqual(out.shape, shape)
Example No. 20
 def __init__(self):
     self.terminal_layer = [
         stax.Dense(self.N_OUTPUTS,
                    W_init=stax.randn(1e-10),
                    b_init=stax.randn(1e-10))
     ]