def BenesBlock(d_model, dropout, mode):
  """Builds a serial block that runs a forward scan, then a backward scan.

  The block duplicates its input, replaces the top copy with a sequence of
  bit indices derived from the input's axis-1 length, scans `_ForwardStep`
  and then `_BackwardStep` over that sequence, and finally keeps the stack
  element at index 1.

  Args:
    d_model: passed through to `_ForwardStep` and `_BackwardStep`.
    dropout: passed through to `_ForwardStep` and `_BackwardStep`.
    mode: passed through to `_ForwardStep` and `_BackwardStep`.

  Returns:
    A `tl.Serial` layer combining the five sub-layers described above.
  """

  def bit_sequence(inputs):
    """Returns `arange(n_bits)`, with `n_bits` computed from `inputs.shape[1]`."""
    last_index = inputs.shape[1] - 1
    # NOTE(review): a length-1 axis makes this log(0); presumably callers
    # always pass longer sequences -- confirm.
    log2_of_last_index = np.log(last_index) / np.log(2.0)
    return jnp.arange(0, np.int32(log2_of_last_index) + 1)

  # Sub-layers are created in the order in which they are applied.
  duplicate_input = tl.Dup()
  make_bit_indices = tl.Fn('BitSeq', bit_sequence, n_out=1)
  forward_scan = tl.Scan(_ForwardStep(d_model, dropout, mode))
  backward_scan = tl.Scan(_BackwardStep(d_model, dropout, mode))
  keep_second = tl.Select([1])
  return tl.Serial(
      duplicate_input,
      make_bit_indices,
      forward_scan,
      backward_scan,
      keep_second,
  )
def test_multi_input(self, backend):
  """Checks `tl.Scan` along axis 1 with a three-argument step function.

  The expected values show the first output as the elementwise sum of the two
  scanned inputs, the second input passed through unchanged, and the third
  input incremented once per step (three steps along axis 1).
  """

  def step(a, b, carry):
    summed = a + b
    return summed, b, carry + 1

  with fastmath.use_backend(backend):
    scan_layer = tl.Scan(tl.Fn('MultiInputFn', step, n_out=2), axis=1)
    first = np.array([[0, 1, 2], [0, 10, 20]])
    second = np.array([[4, 5, 6], [40, 50, 60]])
    initial_carry = np.array([9000, 8000])
    outputs = scan_layer([first, second, initial_carry])
    expected = [
        [[4, 6, 8], [40, 60, 80]],
        [[4, 5, 6], [40, 50, 60]],
        [9003, 8003],
    ]
    self.assertEqual(as_list(outputs), expected)
def test_no_carry(self):
  """Checks `tl.Scan` with `n_carry=0` around a layer that adds one."""

  def increment(x):
    return x + 1

  scan_layer = tl.Scan(tl.Fn('AddOne', increment), n_carry=0)
  inputs = np.array([[1, 3, 7], [10, 30, 70]])
  expected = [[2, 4, 8], [11, 31, 71]]
  self.assertEqual(as_list(scan_layer(inputs)), expected)
def test_predict(self, backend):
  """Checks `tl.Scan` in 'predict' mode across two consecutive calls.

  The same layer instance is called twice. The second expectation
  (96, 100) continues from the carry value 93 produced by the first call,
  even though 90 is passed in again as the carry input.
  """
  # (scanned input, expected outputs) pairs, applied in this order.
  calls = (
      ([[0, 1, 2]], [[[90, 91, 93]], [93]]),
      ([[3, 4]], [[[96, 100]], [100]]),
  )
  with fastmath.use_backend(backend):
    scan_layer = tl.Scan(self._AddWithCarry(), axis=1, mode='predict')
    for sequence, expected in calls:
      outputs = scan_layer([np.array(sequence), np.array([90])])
      self.assertEqual(as_list(outputs), expected)
def test_axis_1(self):
  """Checks `tl.Scan` over axis 1 using the class's `_AddWithCarry` layer.

  Each expected output row is the running sum of the matching input row,
  seeded with that row's carry; the expected final carry is the last
  running-sum value of each row.
  """
  scan_layer = tl.Scan(self._AddWithCarry(), axis=1)
  sequences = np.array([
      [0, 1, 2, 3],
      [0, 10, 20, 30],
      [0, 100, 200, 300],
  ])
  initial_carry = np.array([9000, 8000, 7000])
  expected_outputs = [
      [9000, 9001, 9003, 9006],
      [8000, 8010, 8030, 8060],
      [7000, 7100, 7300, 7600],
  ]
  expected_carry = [9006, 8060, 7600]

  result = scan_layer([sequences, initial_carry])

  self.assertEqual(as_list(result), [expected_outputs, expected_carry])
def test_default_axis(self):
  """Checks `tl.Scan` with no explicit `axis` argument.

  The carry has four elements, matching the length of each input row, and
  the expected outputs accumulate row after row. This corresponds to
  scanning over the leading axis of the input.
  """
  scan_layer = tl.Scan(self._AddWithCarry())
  sequences = np.array([
      [0, 1, 2, 3],
      [0, 10, 20, 30],
      [0, 100, 200, 300],
  ])
  initial_carry = np.array([9000, 8000, 7000, 6000])
  expected_outputs = [
      [9000, 8001, 7002, 6003],
      [9000, 8011, 7022, 6033],
      [9000, 8111, 7222, 6333],
  ]
  expected_carry = [9000, 8111, 7222, 6333]

  result = scan_layer([sequences, initial_carry])

  self.assertEqual(as_list(result), [expected_outputs, expected_carry])
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
  """Returns an RNN language model.

  The input to the model is a tensor of tokens (ints).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding (n_units in the RNN cell)
    n_layers: int: number of RNN layers; only 2 is supported
    rnn_cell: the RNN cell
    rnn_cell_d_state_multiplier: how many times is RNN cell state larger
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference

  Returns:
    An RNN language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.

  Raises:
    ValueError: if `n_layers` is not 2.
  """
  # Validate up front with a real exception: an `assert` is stripped when
  # Python runs with -O, and the old check only fired once the nested cell
  # was being built.
  if n_layers != 2:
    raise ValueError(
        f'Number of layers must be set to 2; instead got {n_layers}.')

  def MultiRNNCell():
    """Multi-layer RNN cell."""
    return tl.Serial(
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )

  zero_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
      depth_multiplier=n_layers * rnn_cell_d_state_multiplier
  )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Branch([], zero_state),
      tl.Scan(MultiRNNCell(), axis=1),
      tl.Select([0], n_in=2),  # Drop RNN state.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
def ChunkedFeedForward(d_model, d_ff, dropout, activation, chunk_size, mode):
  """Chunked feed-forward block with layer normalization at start.

  NOTE(review): the layer normalization is presumably applied inside
  `FeedForward`, which is defined elsewhere -- confirm there.

  Args:
    d_model: forwarded to `FeedForward`.
    d_ff: forwarded to `FeedForward`.
    dropout: forwarded to `FeedForward`.
    activation: forwarded to `FeedForward`.
    chunk_size: number of (batch * length) positions per chunk; any value
      below 1 disables chunking.
    mode: forwarded to `FeedForward`.

  Returns:
    The plain `FeedForward` result when `chunk_size < 1`; otherwise a list of
    four layers that reshape the input into chunks, scan the feed-forward
    block over the chunks, and reshape the result back to the input's shape.
  """
  ff = FeedForward(d_model, d_ff, dropout, activation, mode)
  if chunk_size < 1:
    return ff

  def reshape_to_chunks(x):
    """Reshapes x to [n_chunks, 1, chunk_size] + trailing dimensions."""
    total_positions = x.shape[0] * x.shape[1]
    n_chunks, leftover = divmod(total_positions, chunk_size)
    # Chunks must tile the flattened batch * length positions exactly.
    assert leftover == 0
    return np.reshape(x, [n_chunks, 1, chunk_size, *x.shape[2:]])

  keep_original = tl.Dup()  # Just to have shape for later after scan.
  to_chunks = tl.Fn(reshape_to_chunks, n_out=1)
  scan_over_chunks = tl.Scan(tl.Serial(ff), axis=0, n_carry=0, remat=True)
  restore_shape = tl.Fn(lambda x, y: np.reshape(x, y.shape))
  return [keep_original, to_chunks, scan_over_chunks, restore_shape]
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
  """Returns an RNN language model.

  The model is autoregressive over token IDs:

    - input: rank 2 tensor of token IDs, shape (batch_size, sequence_length);
      elements are integers in `range(vocab_size)`, with `0` marking padding.

    - output: rank 3 tensor of per-position activations over the vocabulary,
      shape (batch_size, sequence_length, `vocab_size`). The final layer is
      `tl.Dense(vocab_size)`; no softmax/log-softmax layer is applied here.

  Args:
    vocab_size: Input vocabulary size -- every input element should be an
        integer in `range(vocab_size)`, typically a tokenizer's token ID.
        Also used as the width of the final dense layer.
    d_model: Embedding depth; also passed to each RNN cell as `n_units`.
    n_layers: Number of RNN layers. Must be 2.
    rnn_cell: Callable that builds one RNN cell layer when called with
        `n_units=d_model`.
    rnn_cell_d_state_multiplier: Multiplier for feature depth of RNN cell
        state; the zero state is created with depth multiplier
        `n_layers * rnn_cell_d_state_multiplier`.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout.
    mode: Passed to the shift-right and dropout layers. If `'predict'`, use
        fast inference; if `'train'`, apply dropout.

  Returns:
    An RNN language model as a layer that maps from a tensor of tokens to
    activations over a vocab set.

  Raises:
    ValueError: if `n_layers` is not 2.
  """
  if n_layers != 2:  # TODO(jonni): Remove n_layers arg, if it can't vary?
    raise ValueError(
        f'Number of layers must be set to 2; instead got {n_layers}.')

  state_depth_multiplier = n_layers * rnn_cell_d_state_multiplier
  initial_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
      depth_multiplier=state_depth_multiplier)

  def stacked_cell():
    """Builds the multi-layer RNN cell that is scanned over the sequence."""
    split_state = tl.Parallel([], tl.Split(n_items=n_layers))
    cells = tl.SerialWithSideOutputs(
        [rnn_cell(n_units=d_model) for _ in range(n_layers)])
    merge_state = tl.Parallel([], tl.Concatenate(n_items=n_layers))
    return tl.Serial(split_state, cells, merge_state)

  # Layers are created in application order.
  pipeline = [
      tl.ShiftRight(mode=mode),
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Branch([], initial_state),
      tl.Scan(stacked_cell(), axis=1),
      tl.Select([0], n_in=2),  # Drop RNN state.
      tl.Dense(vocab_size),
  ]
  return tl.Serial(*pipeline)