# Imports and helper names this excerpt relies on; the paths below follow the
# flax.linen source layout and are assumptions where not shown in the excerpt.
from functools import partial
from typing import Any, Callable, Tuple

import numpy as np
from jax import numpy as jnp
from jax import random
from jax.nn import initializers

from flax.linen.activation import sigmoid, tanh
from flax.linen.linear import Dense, default_kernel_init
from flax.linen.module import compact

PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
Array = Any

zeros = initializers.zeros
orthogonal = initializers.orthogonal

# `RNNCellBase` (the abstract cell base class) and `DenseParams` (a dummy
# module creating parameters that match `Dense`) are defined elsewhere in
# this module and are not reproduced in this excerpt.


class GRUCell(RNNCellBase):
  r"""GRU cell.

  The mathematical definition of the cell is as follows

  .. math::

      \begin{array}{ll}
      r = \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\
      z = \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\
      n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
      h' = (1 - z) * n + z * h \\
      \end{array}

  where x is the input and h is the output of the previous time step.

  Attributes:
    gate_fn: activation function used for gates (default: sigmoid).
    activation_fn: activation function used for output and memory update
      (default: tanh).
    kernel_init: initializer function for the kernels that transform
      the input (default: lecun_normal).
    recurrent_kernel_init: initializer function for the kernels that transform
      the hidden state (default: orthogonal).
    bias_init: initializer for the bias parameters (default: zeros).
  """
  gate_fn: Callable[..., Any] = sigmoid
  activation_fn: Callable[..., Any] = tanh
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = orthogonal()
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros

  @compact
  def __call__(self, carry, inputs):
    """Gated recurrent unit (GRU) cell.

    Args:
      carry: the hidden state of the GRU cell, initialized using
        `GRUCell.initialize_carry`.
      inputs: an ndarray with the input for the current time step. All
        dimensions except the final are considered batch dimensions.

    Returns:
      A tuple with the new carry and the output.
    """
    h = carry
    hidden_features = h.shape[-1]
    # The input and recurrent layers are summed, so only one needs a bias.
    dense_h = partial(Dense,
                      features=hidden_features,
                      use_bias=False,
                      kernel_init=self.recurrent_kernel_init,
                      bias_init=self.bias_init)
    dense_i = partial(Dense,
                      features=hidden_features,
                      use_bias=True,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)
    r = self.gate_fn(dense_i(name='ir')(inputs) + dense_h(name='hr')(h))
    z = self.gate_fn(dense_i(name='iz')(inputs) + dense_h(name='hz')(h))
    # Add a bias on 'hn' because here the two linear transformations aren't
    # directly summed (the recurrent one is first gated by r).
    n = self.activation_fn(dense_i(name='in')(inputs) +
                           r * dense_h(name='hn', use_bias=True)(h))
    new_h = (1. - z) * n + z * h
    return new_h, new_h

  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=zeros):
    """Initialize the RNN cell carry.

    Args:
      rng: random number generator passed to the init_fn.
      batch_dims: a tuple providing the shape of the batch dimensions.
      size: the size or number of features of the memory.
      init_fn: initializer function for the carry.

    Returns:
      An initialized carry for the given RNN cell.
    """
    mem_shape = batch_dims + (size,)
    return init_fn(rng, mem_shape)
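
# Usage sketch (illustrative, not part of the library): initializing and
# applying a GRUCell for a single time step, following the `initialize_carry`
# / `init` / `apply` flow documented above. The function name
# `_gru_cell_example` and all shapes are hypothetical.
def _gru_cell_example():
  batch, in_features, hidden_features = 4, 3, 8
  x = jnp.ones((batch, in_features))
  cell = GRUCell()
  # The hidden size of the cell is inferred from the carry's last dimension.
  carry = GRUCell.initialize_carry(random.PRNGKey(0), (batch,),
                                   hidden_features)
  variables = cell.init(random.PRNGKey(1), carry, x)
  new_carry, y = cell.apply(variables, carry, x)  # for GRU, y is new_carry
  return new_carry, y
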

class LSTMCell(RNNCellBase):
  r"""LSTM cell.

  The mathematical definition of the cell is as follows

  .. math::

      \begin{array}{ll}
      i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
      f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
      g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
      o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
      c' = f * c + i * g \\
      h' = o * \tanh(c') \\
      \end{array}

  where x is the input, h is the output of the previous time step, and c is
  the memory.

  Attributes:
    gate_fn: activation function used for gates (default: sigmoid).
    activation_fn: activation function used for output and memory update
      (default: tanh).
    kernel_init: initializer function for the kernels that transform
      the input (default: lecun_normal).
    recurrent_kernel_init: initializer function for the kernels that transform
      the hidden state (default: orthogonal).
    bias_init: initializer for the bias parameters (default: zeros).
  """
  gate_fn: Callable[..., Any] = sigmoid
  activation_fn: Callable[..., Any] = tanh
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = orthogonal()
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros

  @compact
  def __call__(self, carry, inputs):
    r"""A long short-term memory (LSTM) cell.

    Args:
      carry: the hidden state of the LSTM cell, initialized using
        `LSTMCell.initialize_carry`.
      inputs: an ndarray with the input for the current time step. All
        dimensions except the final are considered batch dimensions.

    Returns:
      A tuple with the new carry and the output.
    """
    c, h = carry
    hidden_features = h.shape[-1]
    # The input and recurrent layers are summed, so only one needs a bias.
    dense_h = partial(Dense,
                      features=hidden_features,
                      use_bias=True,
                      kernel_init=self.recurrent_kernel_init,
                      bias_init=self.bias_init)
    dense_i = partial(Dense,
                      features=hidden_features,
                      use_bias=False,
                      kernel_init=self.kernel_init)
    i = self.gate_fn(dense_i(name='ii')(inputs) + dense_h(name='hi')(h))
    f = self.gate_fn(dense_i(name='if')(inputs) + dense_h(name='hf')(h))
    g = self.activation_fn(dense_i(name='ig')(inputs) + dense_h(name='hg')(h))
    o = self.gate_fn(dense_i(name='io')(inputs) + dense_h(name='ho')(h))
    new_c = f * c + i * g
    new_h = o * self.activation_fn(new_c)
    return (new_c, new_h), new_h

  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=zeros):
    """Initialize the RNN cell carry.

    Args:
      rng: random number generator passed to the init_fn.
      batch_dims: a tuple providing the shape of the batch dimensions.
      size: the size or number of features of the memory.
      init_fn: initializer function for the carry.

    Returns:
      An initialized carry for the given RNN cell.
    """
    key1, key2 = random.split(rng)
    mem_shape = batch_dims + (size,)
    return init_fn(key1, mem_shape), init_fn(key2, mem_shape)
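
# Usage sketch (illustrative, not part of the library): unrolling an LSTMCell
# over a sequence with `jax.lax.scan`. The carry is a (c, h) tuple, which is
# why `initialize_carry` splits the RNG and returns two arrays. The function
# name and all shapes are hypothetical.
def _lstm_cell_scan_example():
  from jax import lax

  time_steps, batch, in_features, hidden_features = 5, 4, 3, 8
  xs = jnp.ones((time_steps, batch, in_features))
  cell = LSTMCell()
  carry = LSTMCell.initialize_carry(random.PRNGKey(0), (batch,),
                                    hidden_features)
  variables = cell.init(random.PRNGKey(1), carry, xs[0])

  def step(carry, x_t):
    # cell.apply returns (new_carry, output), matching lax.scan's contract.
    return cell.apply(variables, carry, x_t)

  (final_c, final_h), ys = lax.scan(step, carry, xs)
  return ys  # shape: (time_steps, batch, hidden_features)
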

class OptimizedLSTMCell(RNNCellBase):
  r"""More efficient LSTM cell that concatenates state components before matmul.

  The parameters are compatible with `LSTMCell`. Note that this cell is often
  faster than `LSTMCell` as long as the hidden size is roughly <= 2048 units.

  The mathematical definition of the cell is the same as `LSTMCell` and as
  follows

  .. math::

      \begin{array}{ll}
      i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
      f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
      g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
      o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
      c' = f * c + i * g \\
      h' = o * \tanh(c') \\
      \end{array}

  where x is the input, h is the output of the previous time step, and c is
  the memory.

  Attributes:
    gate_fn: activation function used for gates (default: sigmoid).
    activation_fn: activation function used for output and memory update
      (default: tanh).
    kernel_init: initializer function for the kernels that transform
      the input (default: lecun_normal).
    recurrent_kernel_init: initializer function for the kernels that transform
      the hidden state (default: orthogonal).
    bias_init: initializer for the bias parameters (default: zeros).
  """
  gate_fn: Callable[..., Any] = sigmoid
  activation_fn: Callable[..., Any] = tanh
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = orthogonal()
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros

  @compact
  def __call__(self, carry: Tuple[Array, Array],
               inputs: Array) -> Tuple[Tuple[Array, Array], Array]:
    r"""An optimized long short-term memory (LSTM) cell.

    Args:
      carry: the hidden state of the LSTM cell, initialized using
        `OptimizedLSTMCell.initialize_carry`.
      inputs: an ndarray with the input for the current time step. All
        dimensions except the final are considered batch dimensions.

    Returns:
      A tuple with the new carry and the output.
    """
    c, h = carry
    hidden_features = h.shape[-1]

    def _concat_dense(inputs, params, use_bias=True):
      """Concatenates the individual kernels and biases given in `params` into
      a single kernel and a single bias for efficiency, then applies them with
      one matmul.
      """
      kernels, biases = zip(*params.values())
      kernel = jnp.asarray(jnp.concatenate(kernels, axis=-1), jnp.float32)
      y = jnp.dot(inputs, kernel)
      if use_bias:
        bias = jnp.asarray(jnp.concatenate(biases, axis=-1), jnp.float32)
        y = y + bias
      # Split the result back into individual (i, f, g, o) outputs. The split
      # points are derived from the kernels so this also works when the
      # params carry no usable biases.
      split_indices = np.cumsum([k.shape[-1] for k in kernels[:-1]])
      ys = jnp.split(y, split_indices, axis=-1)
      return dict(zip(params.keys(), ys))

    # Create params with the same names/shapes as `LSTMCell` for compatibility.
    dense_params_h = {}
    dense_params_i = {}
    for component in ['i', 'f', 'g', 'o']:
      dense_params_i[component] = DenseParams(
          features=hidden_features, use_bias=False,
          kernel_init=self.kernel_init, bias_init=self.bias_init,
          name=f'i{component}')(inputs)
      dense_params_h[component] = DenseParams(
          features=hidden_features, use_bias=True,
          kernel_init=self.recurrent_kernel_init, bias_init=self.bias_init,
          name=f'h{component}')(h)
    dense_h = _concat_dense(h, dense_params_h, use_bias=True)
    dense_i = _concat_dense(inputs, dense_params_i, use_bias=False)

    i = self.gate_fn(dense_h['i'] + dense_i['i'])
    f = self.gate_fn(dense_h['f'] + dense_i['f'])
    g = self.activation_fn(dense_h['g'] + dense_i['g'])
    o = self.gate_fn(dense_h['o'] + dense_i['o'])

    new_c = f * c + i * g
    new_h = o * self.activation_fn(new_c)
    return (new_c, new_h), new_h

  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=zeros):
    """Initialize the RNN cell carry.

    Args:
      rng: random number generator passed to the init_fn.
      batch_dims: a tuple providing the shape of the batch dimensions.
      size: the size or number of features of the memory.
      init_fn: initializer function for the carry.

    Returns:
      An initialized carry for the given RNN cell.
    """
    key1, key2 = random.split(rng)
    mem_shape = batch_dims + (size,)
    return init_fn(key1, mem_shape), init_fn(key2, mem_shape)
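
# Usage sketch (illustrative, not part of the library): since the docstring
# above states the parameter names/shapes match `LSTMCell`, variables
# initialized with one cell should apply cleanly with the other. The function
# name and all shapes are hypothetical.
def _optimized_lstm_compat_example():
  batch, in_features, hidden_features = 4, 3, 8
  x = jnp.ones((batch, in_features))
  carry = OptimizedLSTMCell.initialize_carry(random.PRNGKey(0), (batch,),
                                             hidden_features)
  variables = LSTMCell().init(random.PRNGKey(1), carry, x)
  # Same parameters, but the concatenated-matmul execution path.
  (new_c, new_h), y = OptimizedLSTMCell().apply(variables, carry, x)
  return y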