def SRU(n_units, activation=None):
  """SRU layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  activation = activation or []
  return cb.Serial(
      cb.Dup(),                                        # x, x
      core.Dense(3 * n_units),
      cb.Split(n_items=3),                             # r, f, y, x
      cb.Parallel(core.Sigmoid(), core.Sigmoid()),     # r, f, y, x
      cb.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),    # y * (1 - f), f, r, x
      cb.Parallel([], [], [cb.Dup(), MakeZeroState()]),  # pylint: disable=no-value-for-parameter
      cb.Scan(InnerSRUCell(), axis=1),  # pylint: disable=no-value-for-parameter
      cb.Parallel(activation, cb.Drop()),              # act(c), r, x
      cb.Fn(lambda c, r, x: c * r + x * (1 - r)))
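# A minimal usage sketch (hypothetical helper, not part of the library),
# assuming the layer above is exposed as `trax.layers.SRU` as in the Trax
# library and that `trax.shapes.signature` is available. The input depth is
# chosen equal to `n_units` so the combination in equation (5) broadcasts.
def _example_sru_usage():
  import numpy as np
  from trax import layers as tl
  from trax import shapes

  x = np.zeros((2, 5, 8), dtype=np.float32)   # [batch, length, depth]
  sru = tl.SRU(n_units=8)
  sru.init(shapes.signature(x))               # initialize weights for this shape
  y = sru(x)
  assert y.shape == (2, 5, 8)                 # output depth equals n_units
  return y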
def test_fn_layer_varargs_n_in(self):
  with self.assertRaisesRegexp(ValueError, 'variable arg'):
    cb.Fn(lambda *args: args[0])
  # Check that varargs work when n_in is set.
  id_layer = cb.Fn(lambda *args: args[0], n_in=1)
  input_signature = ShapeDtype((2, 7))
  expected_shape = (2, 7)
  output_shape = base.check_shape_agreement(id_layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
def PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling):
  """Positional embedding for relative attention.

  Returns a layer that, based on queries, keys and the accumulated pool size
  of keys/values up to this layer, calculates sinusoidal positional embeddings
  for relative attention calculations.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    separate_cls: True/False whether a separate cls token is used in the
      calculations.
    total_kv_pooling: Accumulated pool size of keys/values until this layer.

  Returns:
    Positional embedding.
  """

  def PositionsVectors(queries, keys):
    is_funnel_layer = queries.shape != keys.shape
    keys_len, queries_len = keys.shape[1], queries.shape[1]
    current_pooling_ratio = keys_len / queries_len

    # Special case of upsampling.
    if is_funnel_layer and current_pooling_ratio < 1:
      # We should not be doing standard upsampling when we use separate_cls;
      # the cls token is being used for classification.
      assert not separate_cls
      assert (total_kv_pooling * keys_len) % queries_len == 0
      multiplier = ((total_kv_pooling * keys_len) // queries_len)
      positions = jnp.arange(-queries_len + 1, queries_len, 1.0) * multiplier
    else:
      positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling

    if is_funnel_layer and separate_cls:
      # For pool_size 2 without separating cls we have got
      # [0][1][2][3][4][5][6][7] -> [01][23][45][67]
      # With separating cls we have got
      # [0][1][2][3][4][5][6][7] -> [0][12][34][56]
      # The first group will always consist of one token after pooling
      # instead of (pool_size) tokens. We need to add a proper offset so
      # that our shift later on in calculating attention works properly.
      cls_offset = (current_pooling_ratio - 1) * total_kv_pooling
      positions = positions + cls_offset

    return positions

  def Sinusoidal_Embeddings(positions):
    inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
    sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
    pos_emb = jnp.concatenate(
        [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
    return pos_emb

  return cb.Serial(
      cb.Fn('Generate positions vectors', PositionsVectors, n_out=1),
      cb.Fn(
          'Transform to sinusoidal encodings', Sinusoidal_Embeddings, n_out=1))
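# For intuition, the inner `Sinusoidal_Embeddings` computation can be
# reproduced standalone (hypothetical helper, plain numpy in place of `jnp`):
# a vector of relative positions becomes a [len(positions), d_feature] table
# whose first half holds sines and second half cosines of position/frequency
# products.
def _example_sinusoidal_embeddings(d_feature=8):
  import numpy as np

  # Relative positions for keys_len == 4 with total_kv_pooling == 1.
  positions = np.arange(-3, 4, 1.0)                       # [-3., ..., 3.]
  inv_freq = 1 / (10000**(np.arange(0.0, d_feature, 2.0) / d_feature))
  sinusoid_freq = np.einsum('i,j->ij', positions, inv_freq)
  pos_emb = np.concatenate(
      [np.sin(sinusoid_freq), np.cos(sinusoid_freq)], axis=1)
  assert pos_emb.shape == (7, d_feature)                  # one row per position
  return pos_emb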
def test_fn_layer_difficult_n_out(self):
  with self.assertRaisesRegexp(ValueError, 'n_out'):
    # Determining the output of this layer is hard with dummies.
    cb.Fn(lambda x: np.concatenate([x, x], axis=4))
  # Check that this layer works when n_out is set.
  layer = cb.Fn(lambda x: np.concatenate([x, x], axis=4), n_out=1)
  input_signature = ShapeDtype((2, 1, 2, 2, 3))
  expected_shape = (2, 1, 2, 2, 6)
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
def PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling):
  """Positional embeddings.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    separate_cls: True/False whether a separate cls token is used in the
      calculations.
    total_kv_pooling: Accumulated pool size of keys/values until this layer.

  Returns:
    a layer that, based on queries, keys and the accumulated pool size of
    keys/values up to this layer, calculates sinusoidal positional embeddings
    for relative attention calculations.
  """

  def PositionsVectors(queries, keys):
    assert not separate_cls

    keys_len, queries_len = keys.shape[-2], queries.shape[-2]
    funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

    if funnel_factor == 1:
      offset = keys_len - 1
      positions = (jnp.arange(keys_len) - offset) * total_kv_pooling
    else:
      if is_upsampling:
        positions = jnp.arange(-queries_len + 1, queries_len, 1.0)
      else:
        positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling

    return positions

  def Sinusoidal_Embeddings(positions):
    inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
    sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
    pos_emb = jnp.concatenate(
        [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
    return pos_emb

  return cb.Serial(
      cb.Fn('Generate positions vectors', PositionsVectors, n_out=1),
      cb.Fn('Transform to sinusoidal encodings', Sinusoidal_Embeddings,
            n_out=1))
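# A quick numeric check of the `funnel_factor == 1` branch above (hypothetical
# helper with plain numpy and made-up sizes): for keys_len = 4 and
# total_kv_pooling = 2 the positions are the offsets of each key relative to
# the last position, measured in original (pre-pooling) tokens.
def _example_positions_no_funnel():
  import numpy as np

  keys_len, total_kv_pooling = 4, 2
  offset = keys_len - 1
  positions = (np.arange(keys_len) - offset) * total_kv_pooling
  assert list(positions) == [-6, -4, -2, 0]
  return positions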
def test_fn_layer_example(self):
  layer = cb.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))
  input_signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
  expected_shape = ((2, 7), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
  inp = (np.array([2]), np.array([3]))
  x, xs = layer(inp)
  self.assertEqual(int(x), 5)
  self.assertEqual([int(y) for y in xs], [2, 3])
def ShiftRightCls(cls_id):
  """Shifts right and prepends a cls token.

  Returns a layer that shifts input tokens to the right by one and inserts a
  cls token at the beginning, as in the BERT paper.

  Args:
    cls_id: id of the cls token in the embedding dictionary.

  Returns:
    shift_right layer.
  """

  def shift_right(x):
    pad_widths = [(0, 0)] * len(x.shape)
    pad_widths[1] = (1, 0)
    padded = jnp.pad(
        x, pad_widths, mode='constant', constant_values=x.dtype.type(cls_id))
    return padded[:, :-1]

  return cb.Fn('ShiftRightCls()', shift_right)
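# A sketch of what `shift_right` does to a batch of token ids (hypothetical
# helper, plain numpy in place of `jnp`, made-up cls_id): every sequence is
# shifted right by one, the freed first position holds the cls id, and the
# last token is dropped so the length is unchanged.
def _example_shift_right_cls(cls_id=101):
  import numpy as np

  x = np.array([[7, 8, 9],
                [4, 5, 6]])                    # [batch, length] token ids
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[1] = (1, 0)                       # pad one position at the front
  padded = np.pad(x, pad_widths, mode='constant',
                  constant_values=x.dtype.type(cls_id))
  shifted = padded[:, :-1]                     # drop the last token
  # shifted == [[101, 7, 8],
  #             [101, 4, 5]]
  return shifted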
def CreateAttentionMaskLayer():
  """Creates an attention mask layer.

  Returns a layer that, based on queries and keys, calculates the causal
  attention mask for relative attention calculations. Takes q, k, v as input
  and appends the proper mask at the end.

  Causal attention uses masking to prevent a given sequence position from
  attending to positions greater than / following it. This is used, for
  example, when training autoregressive sequence models, or when decoding a
  sequence symbol by symbol.

  Returns:
    an attention mask layer.
  """

  def calculate_mask(queries, keys):
    batch_size = queries.shape[0]
    keys_len, queries_len = keys.shape[-2], queries.shape[-2]
    funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

    return _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                        is_upsampling)

  def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                   is_upsampling):
    """Funnel mask.

    Based on the keys/queries lengths, creates a triangular mask that prevents
    tokens from attending to positions following them. If funnel_factor is not
    equal to 1, due to funnel upsampling or downsampling, it adjusts the
    created mask for funnel attention by repeating each element funnel_factor
    times. This is because after a funnel layer one token attends to
    funnel_factor different tokens in downsampling; during upsampling, on the
    other hand, funnel_factor tokens attend to a single token from before
    upsampling.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: True or False.

    Returns:
      funnel mask.
    """

    if funnel_factor != 1:
      if not is_upsampling:
        mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
        mask = jnp.repeat(mask, funnel_factor, axis=-1)
      else:
        mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_))
        mask = jnp.repeat(mask, funnel_factor, axis=-2)
    else:
      mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

    return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)

  return cb.Branch(
      cb.Select([0]), cb.Select([1]), cb.Select([2]),
      cb.Fn('create attention mask layer', calculate_mask, n_out=1))
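# To see how the funnel adjustment changes a plain causal mask, a numpy sketch
# of the downsampling branch with made-up sizes (hypothetical helper):
# queries_len = 3, funnel_factor = 2, so keys_len = 6. Each column of the
# lower-triangular mask is repeated funnel_factor times, letting one pooled
# query attend to the funnel_factor original key positions it covers.
def _example_funnel_downsampling_mask():
  import numpy as np

  queries_len, funnel_factor = 3, 2
  mask = np.tril(np.ones((queries_len, queries_len), dtype=np.bool_))
  mask = np.repeat(mask, funnel_factor, axis=-1)   # shape (3, 6)
  # mask.astype(int) ==
  # [[1 1 0 0 0 0]
  #  [1 1 1 1 0 0]
  #  [1 1 1 1 1 1]]
  return mask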
def test_fn_layer_fails_wrong_f(self):
  with self.assertRaisesRegexp(ValueError, 'default arg'):
    cb.Fn(lambda x, sth=None: x)
  with self.assertRaisesRegexp(ValueError, 'keyword arg'):
    cb.Fn(lambda x, **kwargs: x)