Code Example #1
File: rse.py  Project: yliu45/trax
def BenesBlock(d_model, dropout, mode):
  """Beneš-style block: a forward and a backward scan over the bit stages."""
  def bit_sequence(inputs):
    seq_length = inputs.shape[1]
    n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1  # ceil(log2(seq_length))
    return jnp.arange(0, n_bits)
  return tl.Serial(
      tl.Dup(),
      tl.Fn('BitSeq', bit_sequence, n_out=1),
      tl.Scan(_ForwardStep(d_model, dropout, mode)),
      tl.Scan(_BackwardStep(d_model, dropout, mode)),
      tl.Select([1]),
  )
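The two Scan calls run the forward and backward halves of the block once per bit index, so the number of scan steps grows logarithmically with the sequence length. A quick plain-NumPy check of the n_bits arithmetic (hypothetical lengths; this is just the formula from bit_sequence, not trax code):

import numpy as np

for seq_length in [2, 4, 8, 16]:
  n_bits = np.int32(np.log(seq_length - 1) / np.log(2.0)) + 1
  print(seq_length, int(n_bits))  # 2->1, 4->2, 8->3, 16->4, i.e. ceil(log2(seq_length))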
Code Example #2
  def test_multi_input(self, backend):
    def _MultiInputFn():  # pylint: disable=invalid-name
      def f(a, b, carry):
        return a + b, b, carry + 1  # last value returned becomes the next carry
      return tl.Fn('MultiInputFn', f, n_out=2)

    with fastmath.use_backend(backend):
      layer = tl.Scan(_MultiInputFn(), axis=1)
      xs = [
          np.array([[0, 1, 2],
                    [0, 10, 20]]),
          np.array([[4, 5, 6],
                    [40, 50, 60]]),
          np.array([9000,
                    8000])
      ]
      ys = layer(xs)
      self.assertEqual(as_list(ys),
                       [[[4, 6, 8],
                         [40, 60, 80]],
                        [[4, 5, 6],
                         [40, 50, 60]],
                        [9003,
                         8003]
                       ])
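Here the last input is the carry: tl.Fn is given n_out=2, so the first two values returned by f become per-position outputs and the third is threaded to the next step. A pure-NumPy sketch of that threading (my own re-derivation of the expected values, not trax internals):

import numpy as np

def manual_scan(a, b, carry):
  outs1, outs2 = [], []
  for t in range(a.shape[1]):  # walk along axis 1
    o1, o2, carry = a[:, t] + b[:, t], b[:, t], carry + 1
    outs1.append(o1)
    outs2.append(o2)
  return np.stack(outs1, axis=1), np.stack(outs2, axis=1), carry

a = np.array([[0, 1, 2], [0, 10, 20]])
b = np.array([[4, 5, 6], [40, 50, 60]])
print(manual_scan(a, b, np.array([9000, 8000])))
# ([[4, 6, 8], [40, 60, 80]], [[4, 5, 6], [40, 50, 60]], [9003, 8003])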
Code Example #3
File: combinators_test.py  Project: srush/trax
  def test_no_carry(self):
    def _AddOne():  # pylint: disable=invalid-name
      return tl.Fn('AddOne', lambda x: x + 1)

    layer = tl.Scan(_AddOne(), n_carry=0)
    x = np.array([[1, 3, 7], [10, 30, 70]])
    y = layer(x)
    self.assertEqual(as_list(y), [[2, 4, 8], [11, 31, 71]])
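With n_carry=0 there is no state to thread, so Scan degenerates to mapping the layer over slices of the scan axis; the result equals applying AddOne to the whole array at once:

import numpy as np

x = np.array([[1, 3, 7], [10, 30, 70]])
print(x + 1)  # [[2, 4, 8], [11, 31, 71]] -- same as the scanned output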
Code Example #4
  def test_predict(self, backend):
    with fastmath.use_backend(backend):
      layer = tl.Scan(self._AddWithCarry(), axis=1, mode='predict')
      xs = [np.array([[0, 1, 2]]), np.array([90])]
      ys = layer(xs)
      self.assertEqual(as_list(ys), [[[90, 91, 93]], [93]])
      xs = [np.array([[3, 4]]), np.array([90])]
      ys = layer(xs)
      self.assertEqual(as_list(ys), [[[96, 100]], [100]])
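In 'predict' mode the carry is kept in the layer's state between calls, which is why the second call continues from 93 rather than from the 90 passed in. The running sums check out (plain Python, just the arithmetic):

carry = 90
print([carry := carry + x for x in [0, 1, 2]])  # [90, 91, 93]
print([carry := carry + x for x in [3, 4]])     # [96, 100]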
Code Example #5
File: combinators_test.py  Project: srush/trax
  def test_axis_1(self):
    layer = tl.Scan(self._AddWithCarry(), axis=1)
    xs = [
        np.array([[0, 1, 2, 3], [0, 10, 20, 30], [0, 100, 200, 300]]),
        np.array([9000, 8000, 7000])
    ]
    ys = layer(xs)
    self.assertEqual(as_list(ys),
                     [[[9000, 9001, 9003, 9006], [8000, 8010, 8030, 8060],
                       [7000, 7100, 7300, 7600]], [9006, 8060, 7600]])
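Since the step is "add input to carry", scanning along axis 1 with an initial carry is just a cumulative sum plus an offset; the expected values can be reproduced directly:

import numpy as np

x = np.array([[0, 1, 2, 3], [0, 10, 20, 30], [0, 100, 200, 300]])
c = np.array([9000, 8000, 7000])
print(c[:, None] + np.cumsum(x, axis=1))  # matches ys[0] above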
Code Example #6
File: combinators_test.py  Project: srush/trax
  def test_default_axis(self):
    layer = tl.Scan(self._AddWithCarry())
    xs = [
        np.array([[0, 1, 2, 3], [0, 10, 20, 30], [0, 100, 200, 300]]),
        np.array([9000, 8000, 7000, 6000])
    ]
    ys = layer(xs)
    self.assertEqual(
        as_list(ys),
        [[[9000, 8001, 7002, 6003], [9000, 8011, 7022, 6033],
          [9000, 8111, 7222, 6333]], [9000, 8111, 7222, 6333]])
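The default scan axis is 0, so the carry here holds one running sum per column (hence four initial values), and the same cumulative-sum identity applies along axis 0:

import numpy as np

x = np.array([[0, 1, 2, 3], [0, 10, 20, 30], [0, 100, 200, 300]])
c = np.array([9000, 8000, 7000, 6000])
print(c[None, :] + np.cumsum(x, axis=0))  # matches ys[0]; the last row is ys[1]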
Code Example #7
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
  """Returns an RNN language model.

  The input to the model is a tensor of tokens (ints).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding (n_units in the RNN cell)
    n_layers: int: number of RNN layers
    rnn_cell: the RNN cell
    rnn_cell_d_state_multiplier: int: how many times larger the RNN cell
      state is than `d_model` (2 for LSTM, which carries hidden and cell state)
    dropout: float: dropout rate (probability of dropping an activation)
    mode: str: 'train', 'eval' or 'predict'; 'predict' is for fast inference

  Returns:
    An RNN language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  def MultiRNNCell():
    """Multi-layer RNN cell."""
    assert n_layers == 2  # The state plumbing below is hard-wired for two layers.
    return tl.Serial(
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )

  zero_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
      depth_multiplier=n_layers * rnn_cell_d_state_multiplier
  )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(d_model, vocab_size),  # Older trax arg order; newer versions take (vocab_size, d_model).
      tl.Dropout(rate=dropout, mode=mode),
      tl.Branch([], zero_state),
      tl.Scan(MultiRNNCell(), axis=1),
      tl.Select([0], n_in=2),  # Drop RNN state.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
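The Branch/Scan/Select trio threads a single concatenated state tensor through the RNN: Branch pairs the embeddings with a zero state, Scan carries it along axis 1, and Select drops it at the end. A NumPy sketch of the split/concatenate state layout inside MultiRNNCell (hypothetical sizes; LSTMCell carries hidden and cell state, hence the multiplier of 2):

import numpy as np

d_model, n_layers, mult = 4, 2, 2                  # hypothetical sizes
state = np.zeros((3, n_layers * mult * d_model))   # batch of 3, one flat carry
per_layer = np.split(state, n_layers, axis=-1)     # what tl.Split(n_items=2) yields
assert per_layer[0].shape == (3, mult * d_model)   # each cell gets its own slice
restored = np.concatenate(per_layer, axis=-1)      # what tl.Concatenate rebuilds
assert restored.shape == state.shape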
Code Example #8
def ChunkedFeedForward(d_model, d_ff, dropout, activation, chunk_size, mode):
  """Chunked feed-forward block with layer normalization at start."""
  ff = FeedForward(d_model, d_ff, dropout, activation, mode)
  if chunk_size < 1:
    return ff
  def reshape_to_chunks(x):
    batch_times_length = x.shape[0] * x.shape[1]
    assert batch_times_length % chunk_size == 0
    n_chunks = batch_times_length // chunk_size
    return np.reshape(x, [n_chunks, 1, chunk_size] + list(x.shape[2:]))
  return [
      tl.Dup(),  # Keep a copy of the input so its shape can restore x after the scan.
      tl.Fn(reshape_to_chunks, n_out=1),
      tl.Scan(tl.Serial(ff), axis=0, n_carry=0, remat=True),
      tl.Fn(lambda x, y: np.reshape(x, y.shape))  # Undo the chunking reshape.
  ]
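The reshape sends (batch, length, ...) to (n_chunks, 1, chunk_size, ...) so the Scan over axis 0 (with remat=True to recompute activations during backprop) applies the feed-forward to one chunk at a time. A round-trip check in plain NumPy (hypothetical sizes that satisfy the divisibility assert):

import numpy as np

batch, length, d, chunk_size = 2, 8, 4, 4
x = np.arange(batch * length * d, dtype=np.float32).reshape(batch, length, d)
n_chunks = (batch * length) // chunk_size          # 4; must divide evenly
chunks = x.reshape(n_chunks, 1, chunk_size, d)     # what reshape_to_chunks builds
y = chunks.reshape(x.shape)                        # what the final Fn restores
assert np.array_equal(x, y)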
Code Example #9
File: rnn.py  Project: yaoshuyin/trax
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
  """Returns an RNN language model.

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Embedding depth throughout the model.
    n_layers: Number of RNN layers.
    rnn_cell: Type of RNN cell; must be a subclass of `Layer`.
    rnn_cell_d_state_multiplier: Multiplier for feature depth of RNN cell
        state.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout.
    mode: If `'predict'`, use fast inference; if `'train'` apply dropout.

  Returns:
    An RNN language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """

  if n_layers != 2:  # TODO(jonni): Remove n_layers arg, if it can't vary?
    raise ValueError(f'Number of layers must be set to 2; instead got'
                     f' {n_layers}.')

  def MultiRNNCell():
    """Multi-layer RNN cell."""
    return tl.Serial(
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )

  zero_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
      depth_multiplier=n_layers * rnn_cell_d_state_multiplier
  )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Branch([], zero_state),
      tl.Scan(MultiRNNCell(), axis=1),
      tl.Select([0], n_in=2),  # Drop RNN state.
      tl.Dense(vocab_size),
  )
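A minimal usage sketch (hypothetical sizes; assumes trax.shapes.signature and Layer.init from the same library version as this example):

import numpy as np
from trax import shapes

model = RNNLM(vocab_size=128, d_model=32, mode='eval')
tokens = np.zeros((2, 8), dtype=np.int32)  # (batch_size, sequence_length)
model.init(shapes.signature(tokens))
out = model(tokens)                        # shape (2, 8, 128): activations over the vocab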