def SRU(n_units, activation=None): """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755. As defined in the paper: (1) y_t = W x_t (+ B optionally, which we do) (2) f_t = sigmoid(Wf x_t + bf) (3) r_t = sigmoid(Wr x_t + br) (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t We assume the input is of shape [batch, length, depth] and recurrence happens on the length dimension. This returns a single layer. It's best to use at least 2, they say in the paper, except inside a Transformer. Args: n_units: output depth of the SRU layer. activation: Optional activation function. Returns: The SRU layer. """ sigmoid_activation = activation_fns.Sigmoid() # pylint: disable=no-value-for-parameter return cb.Serial( # x cb.Branch(core.Dense(3 * n_units), []), # r_f_y, x cb.Split(n_items=3), # r, f, y, x cb.Parallel(sigmoid_activation, sigmoid_activation), # r, f, y, x base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)), # y * (1 - f), f, r, x cb.Parallel([], [], cb.Branch(MakeZeroState(), [])), cb.Scan(InnerSRUCell(), axis=1), cb.Select([0], n_in=2), # act(c), r, x activation or [], base.Fn(lambda c, r, x: c * r + x * (1 - r)))
def LSTM(n_units):
  """LSTM running on axis 1."""
  zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      cb.Branch([], zero_state),
      cb.Scan(LSTMCell(n_units=n_units), axis=1),
      cb.Select([0], n_in=2)  # Drop RNN state.
  )
def LSTM(n_units, mode='train', return_state=False, initial_state=False):
  """LSTM running on axis 1.

  Args:
    n_units: `n_units` for the `LSTMCell`.
    mode: if 'predict' then we save the previous state for one-by-one inference.
    return_state: Boolean. Whether to return the latest state in addition to
      the output. Default: False.
    initial_state: Boolean. Whether the initial RNN state (c, h) should be
      taken from the stack rather than initialized to zeros. Default: False.

  Returns:
    An LSTM layer.
  """
  if not initial_state:
    zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
    if return_state:
      return cb.Serial(cb.Branch([], zero_state),
                       cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
                       name=f'LSTM_{n_units}', sublayers_to_print=[])
    else:
      return cb.Serial(
          cb.Branch([], zero_state),  # fill state RNN with zero.
          cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
          cb.Select([0], n_in=2),  # Drop RNN state.
          # Set the name to LSTM and don't print sublayers.
          name=f'LSTM_{n_units}', sublayers_to_print=[])
  else:
    if return_state:
      return cb.Serial(cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
                       name=f'LSTM_{n_units}', sublayers_to_print=[])
    else:
      return cb.Serial(
          cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
          cb.Select([0], n_in=2),  # Drop RNN state.
          name=f'LSTM_{n_units}', sublayers_to_print=[])
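# A minimal usage sketch for the layer above, assuming the trax package layout
# (`trax.layers`, `trax.shapes`); batch/length/depth values are arbitrary.
# Note that the input depth must equal `n_units` here, because the zero state
# is built from the input's last dimension via MakeZeroState(depth_multiplier=2).
import numpy as np
from trax import layers as tl
from trax import shapes

x = np.zeros((2, 5, 8), dtype=np.float32)  # [batch, length, depth]
lstm = tl.LSTM(n_units=8)                  # defaults: final RNN state is dropped
lstm.init(shapes.signature(x))
y = lstm(x)                                # sequence of hidden states, shape (2, 5, 8)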
def LSTM(n_units, mode='train'):
  """LSTM running on axis 1."""
  zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      cb.Branch([], zero_state),
      cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
      cb.Select([0], n_in=2),  # Drop RNN state.
      # Set the name to LSTM and don't print sublayers.
      name=f'LSTM_{n_units}', sublayers_to_print=[]
  )
def GRU(n_units):
  """GRU running on axis 1."""
  zero_state = MakeZeroState(depth_multiplier=1)  # pylint: disable=no-value-for-parameter
  return cb.Serial(
      cb.Branch([], zero_state),
      cb.Scan(GRUCell(n_units=n_units), axis=1),
      cb.Select([0], n_in=2),  # Drop RNN state.
      # Set the name to GRU and don't print sublayers.
      name=f'GRU_{n_units}', sublayers_to_print=[]
  )
def test_scan_axis1(self):
  @base.layer(n_in=2, n_out=2)
  def add(x, **unused_kwargs):
    res = x[0] + x[1]
    return res, res

  scan = cb.Scan(add(), axis=1)  # pylint: disable=no-value-for-parameter
  input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((3, 7)))
  expected_shape = ((3, 2, 7), (3, 7))
  output_shape = base.check_shape_agreement(scan, input_signature)
  self.assertEqual(output_shape, expected_shape)
def test_scan_multiinput(self):
  @base.layer(n_in=3, n_out=2)
  def foo(x, **unused_kwargs):
    a, b, carry = x
    return a + b, b, carry + 1

  scan = cb.Scan(foo(), axis=1)  # pylint: disable=no-value-for-parameter
  input_signature = (ShapeDtype((3, 2, 7)),
                     ShapeDtype((3, 2, 7)),
                     ShapeDtype((3, 7)))
  expected_shape = ((3, 2, 7), (3, 2, 7), (3, 7))
  output_shape = base.check_shape_agreement(scan, input_signature)
  self.assertEqual(output_shape, expected_shape)
def test_scan_nocarry(self):
  def addone():  # pylint: disable=invalid-name
    return base.Fn('addone', lambda x: x + 1)

  scan_layer = cb.Scan(addone(), n_carry=0)
  input_signature = ShapeDtype((3, 2, 7))
  expected_shape = (3, 2, 7)
  output_shape = base.check_shape_agreement(scan_layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
  inp = np.array([1, 2, 3])
  o = scan_layer(inp)
  self.assertEqual([int(x) for x in o], [2, 3, 4])
def test_scan_multiinput(self):
  def foo():  # pylint: disable=invalid-name
    def f(a, b, carry):
      return a + b, b, carry + 1
    return base.Fn('foo', f, n_out=2)

  scan = cb.Scan(foo(), axis=1)
  input_signature = (ShapeDtype((3, 2, 7)),
                     ShapeDtype((3, 2, 7)),
                     ShapeDtype((3, 7)))
  expected_shape = ((3, 2, 7), (3, 2, 7), (3, 7))
  output_shape = base.check_shape_agreement(scan, input_signature)
  self.assertEqual(output_shape, expected_shape)
def test_scan_axis1(self):
  def add():  # pylint: disable=invalid-name
    def f(x, carry):
      res = x + carry
      return res, res  # output and carry are the same
    return base.Fn('add', f, n_out=2)

  scan = cb.Scan(add(), axis=1)
  input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((3, 7)))
  expected_shape = ((3, 2, 7), (3, 7))
  output_shape = base.check_shape_agreement(scan, input_signature)
  self.assertEqual(output_shape, expected_shape)
def test_scan_nocarry(self):
  @base.layer(n_in=1, n_out=1)
  def addone(x, **unused_kwargs):
    return x + 1

  scan_layer = cb.Scan(addone(), n_carry=0)  # pylint: disable=no-value-for-parameter
  input_signature = ShapeDtype((3, 2, 7))
  expected_shape = (3, 2, 7)
  output_shape = base.check_shape_agreement(scan_layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
  inp = np.array([1, 2, 3])
  o = scan_layer(inp)
  self.assertEqual([int(x) for x in o], [2, 3, 4])
def test_scan_basic(self):
  @base.layer(n_in=2, n_out=2)
  def add(x, **unused_kwargs):
    res = x[0] + x[1]
    return res, res

  scan_layer = cb.Scan(add())  # pylint: disable=no-value-for-parameter
  input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((2, 7)))
  expected_shape = ((3, 2, 7), (2, 7))
  output_shape = base.check_shape_agreement(scan_layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
  inp = (np.array([1, 2, 3]), np.array(0))
  o, v = scan_layer(inp)
  self.assertEqual(int(v), 6)
  self.assertEqual([int(x) for x in o], [1, 3, 6])
def ScanSRUCell(mode, monkey_patched_mask=None):
  """The inner (non-parallel) computation of an SRU."""
  if monkey_patched_mask is None:
    return cb.Scan(InnerSRUCell(), axis=1, mode=mode)

  # This is necessary for the Terraformer model. See comments there.
  # The mask will only be used in Terraformer in predict mode.
  assert mode == 'predict'

  def update_mask(mask, x_times_one_minus_f):  # pylint: disable=invalid-name
    initial = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)
    if initial.shape[1] > 1:
      updated_mask = fastmath.dynamic_update_slice_in_dim(
          initial != 0, mask != 0, 1, axis=1)
    else:
      updated_mask = initial
    return updated_mask, x_times_one_minus_f

  def masked_inner_sru_cell(cur_mask, cur_x_times_one_minus_f,  # pylint: disable=invalid-name
                            cur_f, cur_state):
    res = ((cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask
           + (1 - cur_mask) * cur_state)
    return res, res

  return cb.Serial(
      monkey_patched_mask.get_layer(),
      base.Fn('update_mask', update_mask, n_out=2),
      cb.Scan(base.Fn('MaskedInnerSRUCell', masked_inner_sru_cell, n_out=2),
              axis=1, mode=mode),
  )
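# A small NumPy illustration of the masked state update used in
# masked_inner_sru_cell above: where the mask is 1 the SRU state advances as
# usual, and where it is 0 the previous state is carried through unchanged.
# The concrete numbers are made up for the example.
import numpy as np

cur_f = np.array([0.5, 0.5])
cur_x_times_one_minus_f = np.array([1.0, 1.0])
cur_state = np.array([4.0, 4.0])
cur_mask = np.array([1.0, 0.0])  # update the first position, freeze the second

res = ((cur_f * cur_state + cur_x_times_one_minus_f) * cur_mask
       + (1 - cur_mask) * cur_state)
# res == [3.0, 4.0]: position 0 advanced to 0.5 * 4 + 1 = 3,
# position 1 kept its previous state of 4.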
def test_scan_basic(self):
  def add():  # pylint: disable=invalid-name
    def f(x, carry):
      res = x + carry
      return res, res  # output and carry are the same
    return base.Fn('add', f, n_out=2)

  scan_layer = cb.Scan(add())
  input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((2, 7)))
  expected_shape = ((3, 2, 7), (2, 7))
  output_shape = base.check_shape_agreement(scan_layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
  inp = (np.array([1, 2, 3]), np.array(0))
  o, v = scan_layer(inp)
  self.assertEqual(int(v), 6)
  self.assertEqual([int(x) for x in o], [1, 3, 6])
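# What the test above checks, written as a plain-Python analog of Scan's carry
# semantics (this loop is illustrative only, not how cb.Scan is implemented):
# the wrapped layer's first output is emitted at every step and also becomes
# the carry fed to the next step.
def scan_add(xs, carry=0):
  outputs = []
  for x in xs:
    out = x + carry      # the wrapped layer's first output...
    carry = out          # ...doubles as the next carry
    outputs.append(out)
  return outputs, carry

print(scan_add([1, 2, 3]))  # ([1, 3, 6], 6) -- matches the assertions above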
def SRU(n_units, activation=None):
  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(W_f x_t + b_f) \\
    r_t &= \sigma(W_r x_t + b_r) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  sigmoid_activation = activation_fns.Sigmoid()
  return cb.Serial(                                          # x
      cb.Branch(core.Dense(3 * n_units), []),                # r_f_y, x
      cb.Split(n_items=3),                                   # r, f, y, x
      cb.Parallel(sigmoid_activation, sigmoid_activation),   # r, f, y, x
      base.Fn('',
              lambda r, f, y: (y * (1.0 - f), f, r),         # y * (1 - f), f, r, x
              n_out=3),
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                                # act(c), r, x
      activation or [],
      base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),
      # Set the name to SRU and don't print sublayers.
      name=f'SRU_{n_units}', sublayers_to_print=[])
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
  """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: To offset vanishing gradients in h_t caused by the light
      recurrence and highway computation in deeper layers, a scaling
      correction alpha = (1 + exp(highway_bias) * 2)**0.5 is applied to the
      highway connection.
      ref: https://arxiv.org/abs/1709.02755, page 4, section 3.2 Initialization.
    highway_bias: initial bias of the highway gates.

  Returns:
    The SRU layer.
  """
  # pylint: disable=no-value-for-parameter
  return cb.Serial(                                          # x
      cb.Branch(core.Dense(3 * n_units), []),                # r_f_y, x
      cb.Split(n_items=3),                                   # r, f, y, x
      cb.Parallel(core.Sigmoid(), core.Sigmoid()),           # r, f, y, x
      base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),        # y * (1 - f), f, r, x
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),                                # act(c), r, x
      activation or [],
      base.Fn(lambda c, r, x: c * r + x * (1 - r) * (
          (1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)))
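# A quick check of the rescale constant described in the docstring above,
# alpha = (1 + exp(highway_bias) * 2)**0.5. With the default highway_bias=0
# this gives sqrt(3) ~= 1.732, which matches the 3**0.5 factor hard-coded in
# the FinalSRUGate variant of SRU earlier in this section. The helper name
# `sru_alpha` is ours, for illustration only.
import numpy as np

def sru_alpha(highway_bias=0.0):
  return (1 + np.exp(highway_bias) * 2) ** 0.5

print(sru_alpha(0.0))   # ~1.7320508, i.e. 3**0.5
print(sru_alpha(-3.0))  # ~1.0486, closer to 1 for strongly negative highway bias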