Example #1
def SRU(n_units, activation=None):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  We assume the input has shape [batch, length, depth] and that recurrence
  runs along the length dimension. This returns a single layer; the paper
  recommends stacking at least two, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    # pylint: disable=no-value-for-parameter
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r)))
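A minimal usage sketch for the layer above (assumptions: trax is installed and exposes this constructor as trax.layers.SRU; the input depth must equal n_units, since the final gate mixes c_t with x_t elementwise):

import numpy as np
from trax import layers as tl, shapes

layer = tl.SRU(n_units=8)
x = np.zeros((2, 5, 8), dtype=np.float32)  # [batch, length, depth]
layer.init(shapes.signature(x))
y = layer(x)  # same shape as the input: (2, 5, 8)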
Example #2
def LSTM(n_units):
    """LSTM running on axis 1."""
    zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
    return cb.Serial(
        cb.Branch([], zero_state),
        cb.Scan(LSTMCell(n_units=n_units), axis=1),
        cb.Select([0], n_in=2)  # Drop RNN state.
    )
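A hedged usage sketch (assuming the standard trax.layers export tl.LSTM); because cb.Select([0], n_in=2) drops the final carry, the layer maps [batch, length, depth] to per-step hidden states of the same shape when depth equals n_units:

import numpy as np
from trax import layers as tl, shapes

lstm = tl.LSTM(n_units=16)
x = np.zeros((4, 10, 16), dtype=np.float32)  # [batch, length, depth]
lstm.init(shapes.signature(x))
out = lstm(x)  # (4, 10, 16); the final (c, h) state has been dropped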
Example #3
def LSTM(n_units, mode='train', return_state=False, initial_state=False):
    """LSTM running on axis 1.

  Args:
    n_units: `n_units` for the `LSTMCell`.
    mode: if 'predict' then we save the previous state for one-by-one inference.
    return_state: Boolean. Whether to return the latest state in addition to
      the output. Default: False.
    initial_state: Boolean. Whether the initial RNN state (c, h) is to be
      obtained from the stack. Default: False.

  Returns:
    An LSTM layer.
  """

    if not initial_state:
        zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
        if return_state:
            return cb.Serial(cb.Branch([], zero_state),
                             cb.Scan(LSTMCell(n_units=n_units),
                                     axis=1,
                                     mode=mode),
                             name=f'LSTM_{n_units}',
                             sublayers_to_print=[])
        else:
            return cb.Serial(
                cb.Branch([], zero_state),  # fill state RNN with zero.
                cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
                cb.Select([0], n_in=2),  # Drop RNN state.
                # Set the name to LSTM and don't print sublayers.
                name=f'LSTM_{n_units}',
                sublayers_to_print=[])
    else:
        if return_state:
            return cb.Serial(cb.Scan(LSTMCell(n_units=n_units),
                                     axis=1,
                                     mode=mode),
                             name=f'LSTM_{n_units}',
                             sublayers_to_print=[])
        else:
            return cb.Serial(
                cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
                cb.Select([0], n_in=2),  # Drop RNN state.
                name=f'LSTM_{n_units}',
                sublayers_to_print=[])
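A hypothetical usage sketch for this extended signature (return_state and initial_state are this example's additions, not arguments of the upstream Trax LSTM; the sketch assumes the function above is in scope):

import numpy as np
from trax import shapes

lstm = LSTM(n_units=8, return_state=True)
x = np.zeros((2, 5, 8), dtype=np.float32)  # [batch, length, depth]
lstm.init(shapes.signature(x))
out, state = lstm(x)  # per-step outputs plus the final concatenated (c, h)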
Example #4
def LSTM(n_units, mode='train'):
  """LSTM running on axis 1."""
  zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      cb.Branch([], zero_state),
      cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
      cb.Select([0], n_in=2),  # Drop RNN state.
      # Set the name to LSTM and don't print sublayers.
      name=f'LSTM_{n_units}', sublayers_to_print=[]
  )
Example #5
def GRU(n_units):
  """GRU running on axis 1."""
  zero_state = MakeZeroState(depth_multiplier=1)  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      cb.Branch([], zero_state),
      cb.Scan(GRUCell(n_units=n_units), axis=1),
      cb.Select([0], n_in=2),  # Drop RNN state.
      # Set the name to GRU and don't print sublayers.
      name=f'GRU_{n_units}', sublayers_to_print=[]
  )
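The GRU variant mirrors the LSTM wiring but needs only a single-width zero state (depth_multiplier=1), since GRUCell carries just h rather than a concatenated (c, h). A usage sketch under the same assumptions as the LSTM examples above:

import numpy as np
from trax import layers as tl, shapes

gru = tl.GRU(n_units=16)
x = np.zeros((4, 10, 16), dtype=np.float32)  # [batch, length, depth]
gru.init(shapes.signature(x))
out = gru(x)  # (4, 10, 16); the final hidden state is dropped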
Example #6
  def test_scan_axis1(self):
    @base.layer(n_in=2, n_out=2)
    def add(x, **unused_kwargs):
      res = x[0] + x[1]
      return res, res
    scan = cb.Scan(add(), axis=1)  # pylint: disable=no-value-for-parameter
    input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((3, 7)))
    expected_shape = ((3, 2, 7), (3, 7))
    output_shape = base.check_shape_agreement(scan, input_signature)
    self.assertEqual(output_shape, expected_shape)
Example #7
  def test_scan_multiinput(self):
    @base.layer(n_in=3, n_out=2)
    def foo(x, **unused_kwargs):
      a, b, carry = x
      return a + b, b, carry + 1
    scan = cb.Scan(foo(), axis=1)  # pylint: disable=no-value-for-parameter
    input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((3, 2, 7)),
                       ShapeDtype((3, 7)))
    expected_shape = ((3, 2, 7), (3, 2, 7), (3, 7))
    output_shape = base.check_shape_agreement(scan, input_signature)
    self.assertEqual(output_shape, expected_shape)
Example #8
  def test_scan_nocarry(self):
    def addone():  # pylint: disable=invalid-name
      return base.Fn('addone', lambda x: x + 1)

    scan_layer = cb.Scan(addone(), n_carry=0)
    input_signature = ShapeDtype((3, 2, 7))
    expected_shape = (3, 2, 7)
    output_shape = base.check_shape_agreement(scan_layer, input_signature)
    self.assertEqual(output_shape, expected_shape)
    inp = np.array([1, 2, 3])
    o = scan_layer(inp)
    self.assertEqual([int(x) for x in o], [2, 3, 4])
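With n_carry=0 there is no state threaded between positions, so Scan degenerates to applying the wrapped layer independently at every step, which is why the test above behaves like a plain elementwise map:

assert [x + 1 for x in [1, 2, 3]] == [2, 3, 4]  # same result as scan_layer(inp)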
Example #9
  def test_scan_multiinput(self):
    def foo():  # pylint: disable=invalid-name
      def f(a, b, carry):
        return a + b, b, carry + 1
      return base.Fn('foo', f, n_out=2)

    scan = cb.Scan(foo(), axis=1)
    input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((3, 2, 7)),
                       ShapeDtype((3, 7)))
    expected_shape = ((3, 2, 7), (3, 2, 7), (3, 7))
    output_shape = base.check_shape_agreement(scan, input_signature)
    self.assertEqual(output_shape, expected_shape)
Example #10
  def test_scan_axis1(self):
    def add():  # pylint: disable=invalid-name
      def f(x, carry):
        res = x + carry
        return res, res  # output and carry are the same
      return base.Fn('add', f, n_out=2)

    scan = cb.Scan(add(), axis=1)
    input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((3, 7)))
    expected_shape = ((3, 2, 7), (3, 7))
    output_shape = base.check_shape_agreement(scan, input_signature)
    self.assertEqual(output_shape, expected_shape)
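A plain-numpy restatement of what cb.Scan(add(), axis=1) computes in this test (a sketch of the carry semantics, not Trax internals): the scan iterates over axis 1, threads the carry through f at each step, and stacks the per-step outputs back along axis 1.

import numpy as np

def scan_axis1(f, xs, carry):
  ys = []
  for x in np.moveaxis(xs, 1, 0):  # iterate over the length axis
    y, carry = f(x, carry)
    ys.append(y)
  return np.stack(ys, axis=1), carry

xs = np.ones((3, 2, 7))
carry = np.zeros((3, 7))
ys, final = scan_axis1(lambda x, c: (x + c, x + c), xs, carry)
assert ys.shape == (3, 2, 7) and final.shape == (3, 7)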
Example #11
    def test_scan_nocarry(self):
        @base.layer(n_in=1, n_out=1)
        def addone(x, **unused_kwargs):
            return x + 1

        scan_layer = cb.Scan(addone(), n_carry=0)  # pylint: disable=no-value-for-parameter
        input_signature = ShapeDtype((3, 2, 7))
        expected_shape = (3, 2, 7)
        output_shape = base.check_shape_agreement(scan_layer, input_signature)
        self.assertEqual(output_shape, expected_shape)
        inp = np.array([1, 2, 3])
        o = scan_layer(inp)
        self.assertEqual([int(x) for x in o], [2, 3, 4])
Example #12
  def test_scan_basic(self):
    @base.layer(n_in=2, n_out=2)
    def add(x, **unused_kwargs):
      res = x[0] + x[1]
      return res, res
    scan_layer = cb.Scan(add())  # pylint: disable=no-value-for-parameter
    input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((2, 7)))
    expected_shape = ((3, 2, 7), (2, 7))
    output_shape = base.check_shape_agreement(scan_layer, input_signature)
    self.assertEqual(output_shape, expected_shape)
    inp = (np.array([1, 2, 3]), np.array(0))
    o, v = scan_layer(inp)
    self.assertEqual(int(v), 6)
    self.assertEqual([int(x) for x in o], [1, 3, 6])
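The expected numbers in this test follow directly from the scan recurrence; a plain-Python restatement (not Trax code) of the default axis-0 carry threading:

def scan(f, xs, carry):
  ys = []
  for x in xs:
    y, carry = f(x, carry)
    ys.append(y)
  return ys, carry

outputs, final = scan(lambda x, c: (x + c, x + c), [1, 2, 3], 0)
assert outputs == [1, 3, 6] and final == 6  # matches o and v above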
Example #13
def ScanSRUCell(mode, monkey_patched_mask=None):
    """The inner (non-parallel) computation of an SRU."""
    if monkey_patched_mask is None:
        return cb.Scan(InnerSRUCell(), axis=1, mode=mode)

    # This is necessary for the Terraformer model; see the comments there.
    # The mask is only used in Terraformer in predict mode.
    assert mode == 'predict'

    def update_mask(mask, x_times_one_minus_f):  # pylint: disable=invalid-name
        initial = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)
        if initial.shape[1] > 1:
            updated_mask = fastmath.dynamic_update_slice_in_dim(initial != 0,
                                                                mask != 0,
                                                                1,
                                                                axis=1)
        else:
            updated_mask = initial
        return updated_mask, x_times_one_minus_f

    def masked_inner_sru_cell(
            cur_mask,
            cur_x_times_one_minus_f,
            cur_f,  # pylint: disable=invalid-name
            cur_state):
        res = ((cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask +
               (1 - cur_mask) * cur_state)
        return res, res

    return cb.Serial(
        monkey_patched_mask.get_layer(),
        base.Fn('update_mask', update_mask, n_out=2),
        cb.Scan(base.Fn('MaskedInnerSRUCell', masked_inner_sru_cell, n_out=2),
                axis=1,
                mode=mode),
    )
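The masked cell leaves the previous state untouched wherever the mask is 0 and applies the normal SRU update where it is 1; a quick numeric check (plain numpy, with made-up values):

import numpy as np

f, x_times_one_minus_f, state = 0.5, 3.0, 4.0
mask = np.array([1.0, 0.0])
res = (f * state + x_times_one_minus_f) * mask + (1 - mask) * state
assert res.tolist() == [5.0, 4.0]  # updated where mask=1, carried where mask=0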
Example #14
  def test_scan_basic(self):
    def add():  # pylint: disable=invalid-name
      def f(x, carry):
        res = x + carry
        return res, res  # output and carry are the same
      return base.Fn('add', f, n_out=2)

    scan_layer = cb.Scan(add())
    input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((2, 7)))
    expected_shape = ((3, 2, 7), (2, 7))
    output_shape = base.check_shape_agreement(scan_layer, input_signature)
    self.assertEqual(output_shape, expected_shape)
    inp = (np.array([1, 2, 3]), np.array(0))
    o, v = scan_layer(inp)
    self.assertEqual(int(v), 6)
    self.assertEqual([int(x) for x in o], [1, 3, 6])
Example #15
File: rnn.py Project: jalammar/trax
def SRU(n_units, activation=None):
    r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input has shape [batch, length, depth] and that recurrence
  runs along the length dimension. This returns a single layer; the paper
  recommends stacking at least two, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(
            '',
            lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
            n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        # The 3**0.5 factor is the rescale correction
        # alpha = (1 + exp(highway_bias) * 2)**0.5 with highway_bias = 0.
        base.Fn('FinalSRUGate',
                lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),
        # Set the name to SRU and don't print sublayers.
        name=f'SRU_{n_units}',
        sublayers_to_print=[])
Example #16
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  We assume the input has shape [batch, length, depth] and that recurrence
  runs along the length dimension. This returns a single layer; the paper
  recommends stacking at least two, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: Whether to apply a scaling correction to h_t. To offset vanishing
      gradients in h_t caused by the light recurrence and highway computation
      in deeper layers, the correction alpha = (1 + exp(highway_bias) * 2)**0.5
      is applied (see https://arxiv.org/abs/1709.02755, page 4, section 3.2,
      Initialization).
    highway_bias: initial bias of the highway gates.
  Returns:
    The SRU layer.
  """
    # pylint: disable=no-value-for-parameter
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(core.Sigmoid(), core.Sigmoid()),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r) *
                ((1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)))
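As a worked check of the correction factor (arithmetic added here, not part of the original code): with the default highway_bias=0 it reduces to (1 + e^0 * 2)**0.5 = 3**0.5 ≈ 1.732, exactly the factor hard-coded in Example #15.

import numpy as np

highway_bias = 0
alpha = (1 + np.exp(highway_bias) * 2) ** 0.5
assert np.isclose(alpha, 3 ** 0.5)  # alpha ≈ 1.7320508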