# Example #1
class GRUCell(RNNCellBase):
    r"""GRU cell.

    The mathematical definition of the cell is as follows

    .. math::

        \begin{array}{ll}
        r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\
        z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\
        n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
        h' = (1 - z) * n + z * h \\
        \end{array}

    where x is the input and h is the output of the previous time step.

    Attributes:
        gate_fn: activation function used for the r and z gates
            (default: sigmoid).
        activation_fn: activation function used for the candidate state n
            (default: tanh).
        kernel_init: initializer function for the kernels that transform
            the input (default: lecun_normal).
        recurrent_kernel_init: initializer function for the kernels that
            transform the hidden state (default: orthogonal).
        bias_init: initializer for the bias parameters (default: zeros).
    """
    gate_fn: Callable[..., Any] = sigmoid
    activation_fn: Callable[..., Any] = tanh
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
    recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
        orthogonal())
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros

    @compact  # required: Dense submodules are declared inline below
    def __call__(self, carry, inputs):
        """Gated recurrent unit (GRU) cell.

        Args:
            carry: the hidden state of the GRU cell,
                initialized using `GRUCell.initialize_carry`.
            inputs: an ndarray with the input for the current time step.
                All dimensions except the final are considered batch
                dimensions.

        Returns:
            A tuple with the new carry and the output. For a GRU both are the
            same array: the new hidden state.
        """
        h = carry
        hidden_features = h.shape[-1]
        # For r and z the input and recurrent projections are summed, so only
        # one of them needs a bias; the input side carries it.
        dense_h = partial(Dense,
                          features=hidden_features,
                          use_bias=False,
                          kernel_init=self.recurrent_kernel_init,
                          bias_init=self.bias_init)
        dense_i = partial(Dense,
                          features=hidden_features,
                          use_bias=True,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)
        r = self.gate_fn(dense_i(name='ir')(inputs) + dense_h(name='hr')(h))
        z = self.gate_fn(dense_i(name='iz')(inputs) + dense_h(name='hz')(h))
        # The recurrent term of n is scaled by r before being added, so it
        # needs its own bias (b_hn) in addition to the input-side one (b_in).
        n = self.activation_fn(
            dense_i(name='in')(inputs) +
            r * dense_h(name='hn', use_bias=True)(h))
        new_h = (1. - z) * n + z * h
        return new_h, new_h

    @staticmethod
    def initialize_carry(rng, batch_dims, size, init_fn=zeros):
        """Initialize the RNN cell carry.

        Args:
            rng: random number generator passed to the init_fn.
            batch_dims: a tuple providing the shape of the batch dimensions.
            size: the size or number of features of the memory.
            init_fn: initializer function for the carry.

        Returns:
            An initialized carry for the given RNN cell: a single array of
            shape ``batch_dims + (size,)``.
        """
        mem_shape = batch_dims + (size,)
        return init_fn(rng, mem_shape)
# Example #2
class LSTMCell(RNNCellBase):
    r"""LSTM cell.

    The mathematical definition of the cell is as follows

    .. math::

        \begin{array}{ll}
        i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
        f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
        o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
        c' = f * c + i * g \\
        h' = o * \tanh(c') \\
        \end{array}

    where x is the input, h is the output of the previous time step, and c is
    the memory.

    Attributes:
        gate_fn: activation function used for gates (default: sigmoid).
        activation_fn: activation function used for output and memory update
            (default: tanh).
        kernel_init: initializer function for the kernels that transform
            the input (default: lecun_normal).
        recurrent_kernel_init: initializer function for the kernels that
            transform the hidden state (default: orthogonal).
        bias_init: initializer for the bias parameters (default: zeros).
    """
    gate_fn: Callable[..., Any] = sigmoid
    activation_fn: Callable[..., Any] = tanh
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
    recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
        orthogonal())
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros

    @compact  # required: Dense submodules are declared inline below
    def __call__(self, carry, inputs):
        r"""A long short-term memory (LSTM) cell.

        Args:
            carry: the ``(c, h)`` state of the LSTM cell,
                initialized using `LSTMCell.initialize_carry`.
            inputs: an ndarray with the input for the current time step.
                All dimensions except the final are considered batch
                dimensions.

        Returns:
            A tuple with the new carry ``(new_c, new_h)`` and the output
            ``new_h``.
        """
        c, h = carry
        hidden_features = h.shape[-1]
        # Input and recurrent projections are always summed, so only one of
        # them needs a bias; the recurrent side carries it (b_h* above).
        dense_h = partial(Dense,
                          features=hidden_features,
                          use_bias=True,
                          kernel_init=self.recurrent_kernel_init,
                          bias_init=self.bias_init)
        dense_i = partial(Dense,
                          features=hidden_features,
                          use_bias=False,
                          kernel_init=self.kernel_init)
        i = self.gate_fn(dense_i(name='ii')(inputs) + dense_h(name='hi')(h))
        f = self.gate_fn(dense_i(name='if')(inputs) + dense_h(name='hf')(h))
        g = self.activation_fn(
            dense_i(name='ig')(inputs) + dense_h(name='hg')(h))
        o = self.gate_fn(dense_i(name='io')(inputs) + dense_h(name='ho')(h))
        new_c = f * c + i * g
        new_h = o * self.activation_fn(new_c)
        return (new_c, new_h), new_h

    @staticmethod
    def initialize_carry(rng, batch_dims, size, init_fn=zeros):
        """Initialize the RNN cell carry.

        Args:
            rng: random number generator passed to the init_fn.
            batch_dims: a tuple providing the shape of the batch dimensions.
            size: the size or number of features of the memory.
            init_fn: initializer function for the carry.

        Returns:
            An initialized carry for the given RNN cell: a pair of arrays of
            shape ``batch_dims + (size,)``, each drawn with its own split key.
        """
        key1, key2 = random.split(rng)
        mem_shape = batch_dims + (size,)
        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)
# Example #3
class OptimizedLSTMCell(RNNCellBase):
    r"""More efficient LSTM Cell that concatenates state components before matmul.

    The parameters are compatible with `LSTMCell`. Note that this cell is often
    faster than `LSTMCell` as long as the hidden size is roughly <= 2048 units.

    The mathematical definition of the cell is the same as `LSTMCell` and as
    follows

    .. math::

        \begin{array}{ll}
        i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
        f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
        o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
        c' = f * c + i * g \\
        h' = o * \tanh(c') \\
        \end{array}

    where x is the input, h is the output of the previous time step, and c is
    the memory.

    Attributes:
        gate_fn: activation function used for gates (default: sigmoid).
        activation_fn: activation function used for output and memory update
            (default: tanh).
        kernel_init: initializer function for the kernels that transform
            the input (default: lecun_normal).
        recurrent_kernel_init: initializer function for the kernels that
            transform the hidden state (default: orthogonal).
        bias_init: initializer for the bias parameters (default: zeros).
    """
    gate_fn: Callable[..., Any] = sigmoid
    activation_fn: Callable[..., Any] = tanh
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
    recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
        orthogonal())
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros

    @compact  # required: DenseParams submodules are declared inline below
    def __call__(self, carry: Tuple[Array, Array],
                 inputs: Array) -> Tuple[Tuple[Array, Array], Array]:
        r"""An optimized long short-term memory (LSTM) cell.

        Args:
            carry: the ``(c, h)`` state of the LSTM cell, initialized using
                `OptimizedLSTMCell.initialize_carry`.
            inputs: an ndarray with the input for the current time step. All
                dimensions except the final are considered batch dimensions.

        Returns:
            A tuple with the new carry ``(new_c, new_h)`` and the output
            ``new_h``.
        """
        c, h = carry
        hidden_features = h.shape[-1]

        def _concat_dense(inputs, params, use_bias=True):
            """Apply several dense projections with a single matmul.

            Concatenates the individual kernels (and, if ``use_bias``, the
            biases) given in ``params`` into a single kernel and single bias,
            applies them with one ``jnp.dot``, then splits the result back
            into one array per key of ``params``.

            Args:
                inputs: array whose last axis is contracted with the kernels.
                params: dict mapping component name to a ``(kernel, bias)``
                    pair. The bias is not read when ``use_bias`` is False
                    (so it may be absent/None there).
                use_bias: whether to add the concatenated bias.

            Returns:
                A dict mapping each key of ``params`` to its projection.
            """
            kernels, biases = zip(*params.values())
            kernel = jnp.asarray(jnp.concatenate(kernels, axis=-1),
                                 jnp.float32)

            y = jnp.dot(inputs, kernel)
            if use_bias:
                bias = jnp.asarray(jnp.concatenate(biases, axis=-1),
                                   jnp.float32)
                y = y + bias

            # Split the result back into individual (i, f, g, o) outputs.
            # Widths come from the kernels, not the biases, so this also works
            # for bias-free components (the input side, use_bias=False).
            split_indices = np.cumsum([k.shape[-1] for k in kernels[:-1]])
            ys = jnp.split(y, split_indices, axis=-1)
            return dict(zip(params.keys(), ys))

        # Create params with the same names/shapes as `LSTMCell` for
        # compatibility: 'ii', 'if', 'ig', 'io' (input side, no bias) and
        # 'hi', 'hf', 'hg', 'ho' (recurrent side, with bias).
        dense_params_h = {}
        dense_params_i = {}
        for component in ['i', 'f', 'g', 'o']:
            dense_params_i[component] = DenseParams(
                features=hidden_features, use_bias=False,
                kernel_init=self.kernel_init, bias_init=self.bias_init,
                name=f'i{component}')(inputs)
            dense_params_h[component] = DenseParams(
                features=hidden_features, use_bias=True,
                kernel_init=self.recurrent_kernel_init,
                bias_init=self.bias_init,
                name=f'h{component}')(h)
        dense_h = _concat_dense(h, dense_params_h, use_bias=True)
        dense_i = _concat_dense(inputs, dense_params_i, use_bias=False)

        i = self.gate_fn(dense_h['i'] + dense_i['i'])
        f = self.gate_fn(dense_h['f'] + dense_i['f'])
        g = self.activation_fn(dense_h['g'] + dense_i['g'])
        o = self.gate_fn(dense_h['o'] + dense_i['o'])

        new_c = f * c + i * g
        new_h = o * self.activation_fn(new_c)
        return (new_c, new_h), new_h

    @staticmethod
    def initialize_carry(rng, batch_dims, size, init_fn=zeros):
        """Initialize the RNN cell carry.

        Args:
            rng: random number generator passed to the init_fn.
            batch_dims: a tuple providing the shape of the batch dimensions.
            size: the size or number of features of the memory.
            init_fn: initializer function for the carry.

        Returns:
            An initialized carry for the given RNN cell: a pair of arrays of
            shape ``batch_dims + (size,)``, each drawn with its own split key.
        """
        key1, key2 = random.split(rng)
        mem_shape = batch_dims + (size,)
        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)