def forward(self, x, layer_states):
    """

    Parameters
    ----------
    x
        - layout = 'NT'
            Shape (batch_size, seq_length, C_in)
        - layout = 'TN'
            Shape (seq_length, batch_size, C_in)
    layer_states
        - layout = 'NT'
            Shape (2, batch_size, prev_len, C_in)
        - layout = 'TN'
            Shape (2, prev_len, batch_size, C_in)
    """
    x = self.ln(x)
    if self._layout == 'NT':
        batch_axis, time_axis = 0, 1
        prev_len = npx.shape_array(layer_states)[2]
    else:
        batch_axis, time_axis = 1, 0
        prev_len = npx.shape_array(layer_states)[1]

    query, key, value = np.split(self.qkv(x), 3, axis=-1)
    if layer_states is not None:
        prev_key, prev_value = layer_states[0], layer_states[1]
        key = np.concatenate([prev_key, key], axis=time_axis)
        value = np.concatenate([prev_value, value], axis=time_axis)
    new_states = np.stack([key, value], axis=0)

    # gen mask
    query_pos = npx.arange_like(query, axis=time_axis)
    if prev_len is not None:
        query_pos = query_pos + prev_len
    key_pos = npx.arange_like(key, axis=time_axis)
    # (query_len, key_len)
    mask = (npx.reshape(key_pos, (1, -1)) <=
            npx.reshape(query_pos, (-1, 1))).astype(self._dtype)
    # broadcast to (batch_size, query_len, key_len)
    mask = npx.broadcast_like(np.expand_dims(mask, axis=0), query,
                              lhs_axes=0, rhs_axes=batch_axis)

    query = npx.reshape(query, (-2, -2, self._num_heads, -1))
    key = npx.reshape(key, (-2, -2, self._num_heads, -1))
    value = npx.reshape(value, (-2, -2, self._num_heads, -1))

    out, [_, attn_weight] = self.attention_cell(query, key, value, mask)
    out = self.out_proj(out)
    out = self.hidden_dropout(out)

    return out, new_states
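# A minimal, self-contained sketch (not part of the layer above) of how the causal mask is
# built during incremental decoding: with `prev_len` cached tokens, the new query positions
# start at `prev_len`, so every step can attend to the whole cache plus itself.
from mxnet import np

prev_len, step_len = 3, 2
query_pos = np.arange(step_len) + prev_len              # [3, 4]
key_pos = np.arange(prev_len + step_len)                # [0, 1, 2, 3, 4]
mask = (key_pos.reshape((1, -1)) <= query_pos.reshape((-1, 1))).astype(np.float32)
# mask == [[1, 1, 1, 1, 0],
#          [1, 1, 1, 1, 1]]   # shape (query_len, key_len); broadcast over the batch axis above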
def test_reshape():
    A = np.ones((INT_OVERFLOW, 2))
    A.attach_grad()
    with mx.autograd.record():
        B = npx.reshape(A, (-5))
    assert B.shape == (DOUBLE_INT_OVERFLOW, )
    assert B[0] == 1
    B.backward()
    assert A.grad.shape == (INT_OVERFLOW, 2)
    assert A.grad[0][0] == 1
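# The special values of npx.reshape used throughout this module, demonstrated on a small
# array (my understanding of MXNet 2's npx.reshape semantics: -1 infers a dimension,
# -2 copies an input axis, -4 copies all remaining axes, -5 fuses two consecutive axes):
from mxnet import np, npx

A = np.ones((3, 2, 4, 5))
assert npx.reshape(A, (-2, -2, -5)).shape == (3, 2, 20)   # keep axes 0-1, fuse axes 2-3
assert npx.reshape(A, (-5, -4)).shape == (6, 4, 5)        # fuse axes 0-1, keep the rest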
def forward(self, step_data, past_states):
    mem_states, mem_valid_length, position, past_key_values = past_states
    step_hidden_states = self.model.input_embedding_layer(step_data)
    # NT: (B, d_model) -> (B, 1, d_model); TN: (B, d_model) -> (1, B, d_model)
    step_hidden_states = np.expand_dims(step_hidden_states, axis=self.model._time_axis)
    step_hidden_states, present_key_values = self.model.decoder.incremental_decode(
        step_hidden_states, position, past_key_values, mem_states, mem_valid_length)
    step_hidden_states = self.output_layer(step_hidden_states)
    # NT: (B, 1, vocab_size) -> (B, vocab_size); TN: (1, B, vocab_size) -> (B, vocab_size)
    step_hidden_states = npx.reshape(step_hidden_states, (-5, -1))
    return step_hidden_states, (mem_states, mem_valid_length, position + 1, present_key_values)
def gen_rel_position(data, past_data=None, dtype=np.int32, layout='NT'):
    """Create a matrix of relative position for RelAttentionScoreCell.

    The relative position is defined as the index difference: `mem_i` - `query_j`.
    Note, though, that the implementation here makes sense in self-attention's setting,
    but not in cross-attention's. Hence, both `mem_i` and `query_j` are time indices from
    `data` (or, in incremental decoding's case, the concatenated sequence from the current
    stepwise `data` and the previous steps `past_data`).

    Parameters
    ----------
    data
        The data. Under incremental decoding, seq_length = 1.

        - layout = 'NT'
            Shape (batch_size, seq_length, C)
        - layout = 'TN'
            Shape (seq_length, batch_size, C)
    past_data
        This is only used under incremental decoding. Stacked data from previous steps.
    dtype
        Data type of the mask
    layout
        Layout of the data + past_data

    Returns
    -------
    relative_position :
        Shape (query_length, mem_length) where query_length = mem_length = seq_length
    """
    time_axis = 1 if layout == 'NT' else 0
    if past_data is None:
        position = npx.arange_like(data, axis=time_axis)
    else:
        # for incremental decoding only, where past data is of the shape:
        # NT(NTK): (B, L_seq, num_heads, n_kv) -> (B, L_seq, inner_dim)
        # TN(TNK): (L_seq, B, num_heads, n_kv) -> (L_seq, B, inner_dim)
        past_data = npx.reshape(past_data, (-2, -2, -5))
        position = npx.arange_like(
            np.concatenate([past_data, data], axis=time_axis),
            axis=time_axis
        )
    query_position = np.expand_dims(position, axis=-1)
    mem_position = np.expand_dims(position, axis=0)
    relative_position = mem_position - query_position
    return relative_position.astype(np.int32)  # shape (L_seq, L_seq)
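# Usage sketch for gen_rel_position (assuming the function above and `from mxnet import np, npx`
# are in scope): for a length-3 sequence the matrix is simply mem_i - query_j.
from mxnet import np

data = np.ones((2, 3, 8))        # (batch_size, seq_length, C), layout='NT'
rel = gen_rel_position(data)
# rel == [[ 0,  1,  2],
#         [-1,  0,  1],
#         [-2, -1,  0]]          # shape (seq_length, seq_length)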
def add_vectors_by_position(data, increment, positions):
    """Scatter each batch with the given positions.

    data[i, positions[i, j], ...] += increment[i, j, ...]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length, ...)
    increment
        Input tensor of the values to be added.
        Shape (batch_size, num_disp_position, ...)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length, ...)
    """
    # Here, we use index_add to disperse the output from data:
    # Need to compute
    #   out[i, masked_position[i, j], :] = in[i, j, :]
    # Thus, construct an indices array with shape [2, batch_size * num_masked_position], where
    #   indices[0, i * num_masked_position + j] = i
    #   indices[1, i * num_masked_position + j] = masked_position[i, j]
    # And convert the increment to the shape (batch_size * num_masked_position, ...)
    # Then, out = npx.index_add(data, indices, increment)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx.reshape((-1, )), positions.reshape((-1, ))])
    out = npx.index_add(data, indices, npx.reshape(increment, (-5, -4)))
    return out
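# Hypothetical usage of add_vectors_by_position (a sketch, assuming the function above and
# `from mxnet import np, npx` are in scope); positions are chosen without duplicates.
from mxnet import np

data = np.zeros((2, 4))
increment = np.array([[1., 2.], [3., 4.]])
positions = np.array([[0, 2], [1, 3]])
out = add_vectors_by_position(data, increment, positions)
# out == [[1., 0., 2., 0.],
#         [0., 3., 0., 4.]]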
def update_vectors_by_position(data, val, positions):
    """Update each batch with the given positions.

    Considered as a reversed process of "select_vectors_by_position", this is an operator
    similar to "add_vectors_by_position" that updates the results instead of adding.

    data[i, positions[i, j], :] = val[i, j, :]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length)
    val
        Input tensor of the new values.
        Shape (batch_size, num_disp_position)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length)
    """
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx.reshape((-1, )), positions.reshape((-1, ))])
    out = npx.index_update(data, indices, npx.reshape(val, (-5, -4)))
    return out
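# The same shapes with update_vectors_by_position: values at the given positions are
# overwritten instead of incremented (again a sketch, assuming the function above is in scope).
from mxnet import np

data = np.ones((2, 4))
val = np.array([[5., 6.], [7., 8.]])
positions = np.array([[0, 2], [1, 3]])
out = update_vectors_by_position(data, val, positions)
# out == [[5., 1., 6., 1.],
#         [1., 7., 1., 8.]]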
def forward(self, rel_positions, query=None):
    """Forward function

    Parameters
    ----------
    rel_positions
        The relative shifts. Shape (query_length, mem_length).
        Each element represents the shift between the :math:`i-th` element of query and
        the :math:`j-th` element of memory.
    query
        The query for computing the relative scores. The shape depends on the layout.
        If we use T5 attention, the query will not be used.

    Returns
    -------
    rel_scores
        The relative attention scores.
        Can have shape (batch_size, num_heads, query_length, mem_length)
        or (num_heads, query_length, mem_length)
    """
    if self._method == 'transformer_xl' or self._method == 'shaw':
        assert query is not None, 'Must specify query if method={}'.format(self._method)
        if self._bidirectional:
            if self._max_distance is not None:
                rel_positions = np.clip(rel_positions,
                                        a_min=-self._max_distance,
                                        a_max=self._max_distance)
        else:
            if self._max_distance is not None:
                rel_positions = np.clip(rel_positions,
                                        a_min=0, a_max=self._max_distance)
        # uniq_rel.shape = (#uniq,), rev_index.shape = (L_q, L_m)
        uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True)
        uniq_rel_pos_embed = self._rel_pos_embed(uniq_rel)
        if self._method == 'transformer_xl':
            uniq_rel_pos_embed = self._rel_proj(self._dropout_layer(uniq_rel_pos_embed))
        # Shape (#uniq, K, C_q)
        uniq_rel_pos_embed = npx.reshape(uniq_rel_pos_embed,
                                         (-2, self._num_heads, self._head_query_units))
        # Calculate the dot-product between query and the relative positional embeddings.
        # After the calculation, rel_score.shape = (L_q, #uniq, N, K)
        if self._layout == 'NKT':
            # query_for_rel: (N, K, L_q, C_q)
            if self._use_einsum:
                rel_score = np.einsum('bnid,jnd->ijbn', query, uniq_rel_pos_embed)
            else:
                rel_score = np.transpose(
                    np.matmul(query, np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                    (2, 3, 0, 1)
                )
        elif self._layout == 'NTK':
            # query_for_rel: (N, L_q, K, C_q)
            if self._use_einsum:
                rel_score = np.einsum('bind,jnd->ijbn', query, uniq_rel_pos_embed)
            else:
                rel_score = np.transpose(
                    np.matmul(np.swapaxes(query, 1, 2),
                              np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                    (2, 3, 0, 1)
                )
        elif self._layout == 'TNK':
            # query_for_rel: (L_q, N, K, C_q)
            if self._use_einsum:
                rel_score = np.einsum('ibnd,jnd->ijbn', query, uniq_rel_pos_embed)
            else:
                rel_score = np.transpose(
                    np.matmul(np.transpose(query, (1, 2, 0, 3)),
                              np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                    (2, 3, 0, 1)
                )
        else:
            raise NotImplementedError
        # We use gather_nd to select the elements
        # TODO(sxjscience) Use advanced indexing once available
        rev_index = npx.reshape_like(rev_index, rel_positions).astype(np.int32)
        query_idx = np.expand_dims(npx.arange_like(rel_positions, axis=0).astype(np.int32),
                                   axis=-1) + np.zeros_like(rev_index)
        rel_score = npx.gather_nd(rel_score, np.stack([query_idx, rev_index]))
        rel_score = np.transpose(rel_score, (2, 3, 0, 1))
    elif self._method == 't5':
        # shape is (K, L_q, L_m)
        rel_score = self._rel_pos_embed(rel_positions).transpose((2, 0, 1))
    else:
        raise NotImplementedError
    return rel_score
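# A small sketch of the deduplication trick used in the 'transformer_xl'/'shaw' branch above:
# embeddings are computed once per unique relative offset and then gathered back via the
# inverse index (np.unique returns the inverse flattened, hence the reshape_like in the code).
from mxnet import np

rel_positions = np.array([[0, -1, -2],
                          [1,  0, -1]])
uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True)
# uniq_rel  == [-2, -1, 0, 1]
# rev_index == [2, 1, 0, 3, 2, 1]   # flattened, shape (L_q * L_m,)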
def gen_self_attn_mask(data, valid_length=None, dtype: type = np.float32,
                       attn_type: str = 'full', layout: str = 'NT'):
    """Generate the mask used for the encoder, i.e., self-attention.

    In our implementation, 1 --> not masked, 0 --> masked

    Let's consider the data with two samples:

    .. code-block:: none

        data =
            [['I',   'can', 'now',   'use', 'numpy', 'in',  'Gluon@@', 'NLP'  ],
             ['May', 'the', 'force', 'be',  'with',  'you', '<PAD>',   '<PAD>']]
        valid_length =
            [8, 6]

    - attn_type = 'causal'
        Each token will attend to itself + the tokens before.
        It will not attend to tokens in the future.

        For our example, the mask of the first sample is

        .. code-block:: none

                       ['I', 'can', 'now', 'use', 'numpy', 'in', 'Gluon@@', 'NLP']
            'I':         1,    0,     0,     0,      0,     0,      0,       0
            'can':       1,    1,     0,     0,      0,     0,      0,       0
            'now':       1,    1,     1,     0,      0,     0,      0,       0
            'use':       1,    1,     1,     1,      0,     0,      0,       0
            'numpy':     1,    1,     1,     1,      1,     0,      0,       0
            'in':        1,    1,     1,     1,      1,     1,      0,       0
            'Gluon@@':   1,    1,     1,     1,      1,     1,      1,       0
            'NLP':       1,    1,     1,     1,      1,     1,      1,       1

        The mask of the second sample is

        .. code-block:: none

                       ['May', 'the', 'force', 'be', 'with', 'you', '<PAD>', '<PAD>']
            'May':        1,     0,      0,     0,     0,      0,     0,       0
            'the':        1,     1,      0,     0,     0,      0,     0,       0
            'force':      1,     1,      1,     0,     0,      0,     0,       0
            'be':         1,     1,      1,     1,     0,      0,     0,       0
            'with':       1,     1,      1,     1,     1,      0,     0,       0
            'you':        1,     1,      1,     1,     1,      1,     0,       0
            '<PAD>':      0,     0,      0,     0,     0,      0,     0,       0
            '<PAD>':      0,     0,      0,     0,     0,      0,     0,       0

    - attn_type = 'full'
        Each token will attend to both the tokens before and in the future.

        For our example, the mask of the first sample is

        .. code-block:: none

                       ['I', 'can', 'now', 'use', 'numpy', 'in', 'Gluon@@', 'NLP']
            'I':         1,    1,     1,     1,      1,     1,      1,       1
            'can':       1,    1,     1,     1,      1,     1,      1,       1
            'now':       1,    1,     1,     1,      1,     1,      1,       1
            'use':       1,    1,     1,     1,      1,     1,      1,       1
            'numpy':     1,    1,     1,     1,      1,     1,      1,       1
            'in':        1,    1,     1,     1,      1,     1,      1,       1
            'Gluon@@':   1,    1,     1,     1,      1,     1,      1,       1
            'NLP':       1,    1,     1,     1,      1,     1,      1,       1

        The mask of the second sample is

        .. code-block:: none

                       ['May', 'the', 'force', 'be', 'with', 'you', '<PAD>', '<PAD>']
            'May':        1,     1,      1,     1,     1,      1,     0,       0
            'the':        1,     1,      1,     1,     1,      1,     0,       0
            'force':      1,     1,      1,     1,     1,      1,     0,       0
            'be':         1,     1,      1,     1,     1,      1,     0,       0
            'with':       1,     1,      1,     1,     1,      1,     0,       0
            'you':        1,     1,      1,     1,     1,      1,     0,       0
            '<PAD>':      0,     0,      0,     0,     0,      0,     0,       0
            '<PAD>':      0,     0,      0,     0,     0,      0,     0,       0

    Parameters
    ----------
    data
        The data.

        - layout = 'NT'
            Shape (batch_size, seq_length, C)
        - layout = 'TN'
            Shape (seq_length, batch_size, C)
    valid_length
        Shape (batch_size,)
    dtype
        Data type of the mask
    attn_type
        Can be 'full' or 'causal'
    layout
        The layout of the data

    Returns
    -------
    mask
        Shape (batch_size, seq_length, seq_length)
    """
    if layout == 'NT':
        batch_axis, time_axis = 0, 1
    elif layout == 'TN':
        batch_axis, time_axis = 1, 0
    else:
        raise NotImplementedError('Unsupported layout={}'.format(layout))
    if attn_type == 'full':
        if valid_length is not None:
            valid_length = valid_length.astype(dtype)
            steps = npx.arange_like(data, axis=time_axis)  # (seq_length,)
            mask1 = (npx.reshape(steps, (1, 1, -1))
                     < npx.reshape(valid_length, (-2, 1, 1)))
            mask2 = (npx.reshape(steps, (1, -1, 1))
                     < npx.reshape(valid_length, (-2, 1, 1)))
            mask = mask1 * mask2
        else:
            # TODO(sxjscience) optimize
            seq_len_ones = np.ones_like(npx.arange_like(data, axis=time_axis))  # (seq_length,)
            batch_ones = np.ones_like(npx.arange_like(data, axis=batch_axis))   # (batch_size,)
            mask = batch_ones.reshape((-1, 1, 1)) * seq_len_ones.reshape((1, -1, 1))\
                   * seq_len_ones.reshape((1, 1, -1))
    elif attn_type == 'causal':
        steps = npx.arange_like(data, axis=time_axis)
        # mask: (seq_length, seq_length)
        # batch_mask: (batch_size, seq_length)
        mask = (np.expand_dims(steps, axis=0) <= np.expand_dims(steps, axis=1)).astype(dtype)
        if valid_length is not None:
            valid_length = valid_length.astype(dtype)
            batch_mask = (np.expand_dims(steps, axis=0)
                          < np.expand_dims(valid_length, axis=-1)).astype(dtype)
            mask = mask * np.expand_dims(batch_mask, axis=-1)
        else:
            batch_ones = np.ones_like(npx.arange_like(data, axis=batch_axis),
                                      dtype=dtype)  # (batch_size,)
            mask = mask * batch_ones.reshape((-1, 1, 1))
    else:
        raise NotImplementedError
    return mask.astype(np.bool)
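# Usage sketch for gen_self_attn_mask, mirroring the docstring example (assuming the function
# above and `from mxnet import np, npx` are in scope).
from mxnet import np

data = np.ones((2, 8, 4))            # (batch_size, seq_length, C), layout='NT'
valid_length = np.array([8, 6])
mask = gen_self_attn_mask(data, valid_length, attn_type='causal', layout='NT')
assert mask.shape == (2, 8, 8)       # (batch_size, seq_length, seq_length)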
def multi_head_dot_attn(query, key, value,
                        mask=None,
                        edge_scores=None,
                        dropout: float = 0.0,
                        scaled: bool = True, normalized: bool = False,
                        eps: float = 1E-6, query_head_units: Optional[int] = None,
                        layout: str = 'NKT',
                        use_einsum: bool = False):
    """Multihead dot product attention between the query, key, value.

    scaled is False, normalized is False:
        D(h_q, h_k) = <h_q, h_k>
    scaled is True, normalized is False:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    scaled is False, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    scaled is True, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    If edge_scores is provided, we will calculate the attention as
        scores = D(h_q, h_k) + EdgeScore_{q, k}

    Parameters
    ----------
    query
        Query. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, query_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, query_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads, key_dim)
    key
        Key. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, key_dim)
    value
        Value. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, value_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, value_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, value_dim)
    mask
        Mask between query and memory. Shape (batch_size, query_length, mem_length)
    edge_scores
        The edge attention score. Shape can be any shape that is broadcastable to
        (batch_size, num_heads, query_length, mem_length)
    dropout
        Dropout rate
    scaled
        Whether to divide the attention weights by the sqrt of the query dimension.
        This is first proposed in "[NIPS2017] Attention is all you need."::

        .. code-block:: none

            score = <h_q, h_k> / sqrt(dim_q)

    normalized
        If turned on, the cosine distance is used, i.e::

        .. code-block:: none

            score = <h_q / ||h_q||, h_k / ||h_k||>

    eps
        The epsilon value used in L2 normalization
    query_head_units
        The units of each query head. Required when `scaled` is True.
    layout
        This stands for the layout of the attention cell. The shape of the input/output will
        depend on the layout. Currently, we support 'NKT', 'NTK' and 'TNK' in which
        'N' means the batch_size, 'K' means the head, and 'T' means the length dimension.
    use_einsum
        Whether to use einsum for the computation

    Returns
    -------
    context_vec
        - layout is 'NKT' or 'NTK'
            Shape (batch_size, query_length, num_heads * value_units)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads * value_units)
    additional_info
        scores:
            Shape (batch_size, num_head, query_length, mem_length)
        attn_weight:
            Shape (batch_size, num_head, query_length, mem_length)
    """
    # TODO(sxjscience) Profile layout
    if normalized:
        query = l2_normalize(query, axis=-1, eps=eps)
        key = l2_normalize(key, axis=-1, eps=eps)
    if scaled:
        if query_head_units is None:
            raise NotImplementedError('You will need to specify query_head_units!')
        else:
            scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout == 'NKT':
        # 1. Expand the dimension of the mask:
        #    (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #    Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #    (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bnjc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'NTK':
        # 1. Expand the dimension of the mask:
        #    (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #    Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum('binc,bjnc->bnij', query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
                                   transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #    (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bjnc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'TNK':
        # 1. Expand the dimension of the mask:
        #    (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #    Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        #    This layout structure can be implemented very efficiently because B, N are
        #    consecutive to each other. To have a clear picture of what's happening, we may
        #    consider the (i, j)th element of the output
        #        out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T,
        #    which is just one GEMM call. We can thus implement the whole kernel via a single
        #    call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum('ibnc,jbnc->bnij', query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                   key.transpose((1, 2, 3, 0)))
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #    (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
        #    Again, we can implement it via a single call to batched GEMM with stride.
        #    Shape (B, N, L_query, C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,jbnc->ibnc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    else:
        raise NotImplementedError('layout="{}" is not supported! '
                                  'We only support layout = "NKT", "NTK", and "TNK".'
                                  .format(layout))
    return context_vec, [scores, attn_weights]
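# Usage sketch for multi_head_dot_attn with layout='NTK' (a sketch only; it assumes the
# module's masked_softmax / l2_normalize helpers and `from mxnet import np, npx` are in scope).
from mxnet import np

batch_size, num_heads, seq_length, head_dim = 2, 4, 6, 8
query = np.ones((batch_size, seq_length, num_heads, head_dim))
key = np.ones((batch_size, seq_length, num_heads, head_dim))
value = np.ones((batch_size, seq_length, num_heads, head_dim))
mask = np.ones((batch_size, seq_length, seq_length))
context_vec, [scores, attn_weights] = multi_head_dot_attn(
    query, key, value, mask=mask, query_head_units=head_dim, layout='NTK')
assert context_vec.shape == (batch_size, seq_length, num_heads * head_dim)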
def gen_mem_attn_mask(mem, mem_valid_length, data, data_valid_length=None,
                      dtype=np.float32, layout: str = 'NT'):
    """Generate the mask used for the decoder. All the query slots attend to the memory slots.

    In our implementation, 1 --> not masked, 0 --> masked

    Let's consider the data + mem with a batch of two samples:

    .. code-block:: none

        mem = [['I',   'can', 'now',   'use'  ],
               ['May', 'the', 'force', '<PAD>']]
        mem_valid_length =
            [4, 3]
        data = [['numpy', 'in',   'Gluon@@', 'NLP'  ],
                ['be',    'with', 'you',     '<PAD>']]
        data_valid_length =
            [4, 3]

    For our example, the mask of the first sample is

    .. code-block:: none

                   ['I', 'can', 'now', 'use']
        'numpy':     1,    1,     1,    1
        'in':        1,    1,     1,    1
        'Gluon@@':   1,    1,     1,    1
        'NLP':       1,    1,     1,    1

    The mask of the second sample is

    .. code-block:: none

                   ['May', 'the', 'force', '<PAD>']
        'be':        1,     1,      1,       0
        'with':      1,     1,      1,       0
        'you':       1,     1,      1,       0
        '<PAD>':     0,     0,      0,       0

    Parameters
    ----------
    mem
        - layout = 'NT'
            Shape (batch_size, mem_length, C_mem)
        - layout = 'TN'
            Shape (mem_length, batch_size, C_mem)
    mem_valid_length :
        Shape (batch_size,)
    data
        - layout = 'NT'
            Shape (batch_size, query_length, C_data)
        - layout = 'TN'
            Shape (query_length, batch_size, C_data)
    data_valid_length :
        Shape (batch_size,)
    dtype
        Data type of the mask
    layout
        Layout of the data + mem tensor

    Returns
    -------
    mask :
        Shape (batch_size, query_length, mem_length)
    """
    if layout == 'NT':
        batch_axis, time_axis = 0, 1
    elif layout == 'TN':
        batch_axis, time_axis = 1, 0
    else:
        raise NotImplementedError('Unsupported layout={}'.format(layout))
    mem_valid_length = mem_valid_length.astype(dtype)
    mem_steps = npx.arange_like(mem, axis=time_axis)    # (mem_length,)
    data_steps = npx.arange_like(data, axis=time_axis)  # (query_length,)
    # mem_mask: (B, 1, mem_length)
    mem_mask = (npx.reshape(mem_steps, (1, 1, -1))
                < npx.reshape(mem_valid_length, (-2, 1, 1))).astype(dtype)
    if data_valid_length is not None:
        data_valid_length = data_valid_length.astype(dtype)
        # data_mask: (B, query_length, 1)
        data_mask = (npx.reshape(data_steps, (1, -1, 1))
                     < npx.reshape(data_valid_length, (-2, 1, 1))).astype(dtype)
        mask = mem_mask * data_mask
    else:
        query_length_ones = np.ones_like(data_steps)
        mask = query_length_ones.reshape((1, -1, 1)) * mem_mask
    return mask.astype(np.bool)
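# Usage sketch for gen_mem_attn_mask, mirroring the docstring example (assuming the function
# above and `from mxnet import np, npx` are in scope).
from mxnet import np

mem = np.ones((2, 4, 8))             # (batch_size, mem_length, C_mem), layout='NT'
data = np.ones((2, 4, 8))            # (batch_size, query_length, C_data)
mem_valid_length = np.array([4, 3])
data_valid_length = np.array([4, 3])
mask = gen_mem_attn_mask(mem, mem_valid_length, data, data_valid_length, layout='NT')
assert mask.shape == (2, 4, 4)       # (batch_size, query_length, mem_length)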
def transpose_for_scores(self, x):
    # NT -> NTK: (B, L_seq, inner_dim) -> (B, L_seq, num_heads, n_kv)
    # TN -> TNK: (L_seq, B, inner_dim) -> (L_seq, B, num_heads, n_kv)
    return npx.reshape(x, (-2, -2, self._num_heads, -1))
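# The head-splitting reshape used by transpose_for_scores, shown standalone: -2 keeps an axis
# and -1 infers the per-head size, so inner_dim is split into (num_heads, n_kv).
from mxnet import np, npx

x = np.ones((2, 5, 12))                          # (B, L_seq, inner_dim), layout='NT'
y = npx.reshape(x, (-2, -2, 3, -1))              # num_heads = 3
assert y.shape == (2, 5, 3, 4)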
def forward(self, data, mem, rel_positions, mask, query_r_bias, query_k_bias):
    """

    Parameters
    ----------
    data
        The input data.
        layout = 'NT':
            Shape (batch_size, query_length, units)
        layout = 'TN':
            Shape (query_length, batch_size, units)
    mem
        The memory.
        layout = 'NT':
            Shape (batch_size, mem_length, units)
        layout = 'TN':
            Shape (mem_length, batch_size, units)
    rel_positions
        The relative positions between data and [mem, data].
        Shape (query_length, mem_length + query_length).
        A positive value means that query is after the memory, i.e.,
        query_location - mem_location.
    mask
        Mask between the query and the memory + query.
        1 --> will be used, 0 --> won't be used.
        Shape (batch_size, query_length, mem_length + query_length)
    query_r_bias
        The query bias for calculating the relative scores.
        Shape (num_heads, query_head_units)
    query_k_bias
        The key bias for calculating the relative scores.
        Shape (num_heads, query_head_units)

    Returns
    -------
    out
        - layout = 'NT'
            Shape (batch_size, query_length, units)
        - layout = 'TN'
            Shape (query_length, batch_size, units)
    """
    if self._layout == 'NT':
        context = np.concatenate([mem, data], axis=1)
    elif self._layout == 'TN':
        context = np.concatenate([mem, data], axis=0)
    else:
        raise NotImplementedError
    if self._pre_norm:
        query = self.attn_query(self.layer_norm(data))
        key_value = self.attn_kv(self.layer_norm(context))
        key, value = np.split(key_value, 2, axis=-1)
    else:
        query = self.attn_query(data)
        key_value = self.attn_kv(context)
        key, value = np.split(key_value, 2, axis=-1)
    query = npx.reshape(query, (-2, -2, self._num_heads, -1))
    key = npx.reshape(key, (-2, -2, self._num_heads, -1))
    value = npx.reshape(value, (-2, -2, self._num_heads, -1))
    # Compute attention
    rel_score = self.rel_pos_score_cell(rel_positions, query + query_r_bias)
    out, _ = self.attn_cell(query + query_k_bias, key, value, mask, rel_score)
    out = self.dropout_layer(out)
    if self._pre_norm:
        out = data + out
    else:
        out = self.layer_norm(data + out)
    out = self.ffn(out)
    return out
def forward(self, data, indices):
    mask = indices < 3
    data = npx.reshape(data, (-1, -2), reverse=True)
    mask = np.reshape(mask, (-1, ))
    sel = nd.np._internal.boolean_mask(data, mask)
    return sel
def forward(self, data, attn_mask):
    """

    Parameters
    ----------
    data
        - layout = 'NT'
            Shape (batch_size, seq_length, C_in)
        - layout = 'TN'
            Shape (seq_length, batch_size, C_in)
    attn_mask
        The attention mask
        Shape (batch_size, seq_length, seq_length)

    Returns
    -------
    out
        - layout = 'NT'
            Shape (batch_size, seq_length, C_out)
        - layout = 'TN'
            Shape (seq_length, batch_size, C_out)
    attn_weight
        Shape (batch_size, seq_length, seq_length)
    """
    if self._use_bottleneck:
        bn_proj = self.in_bottleneck_proj(data)
        bn_proj = self.in_bottleneck_ln(bn_proj)
        input = bn_proj
        if self._bottleneck_strategy == 'qk_sharing':
            # for Mobile Bert
            qk_shared = self.shared_qk(data)
            qk_shared = self.shared_qk_ln(qk_shared)
            query = qk_shared
            key = qk_shared
            value = data
        elif self._bottleneck_strategy == 'from_bottleneck':
            # for Mobile Bert Tiny
            query = bn_proj
            key = bn_proj
            value = bn_proj
        elif self._bottleneck_strategy == 'from_input':
            query = data
            key = data
            value = data
        else:
            raise NotImplementedError
    else:
        input = data
        query = data
        key = data
        value = data

    query = npx.reshape(self.attn_query(query), (-2, -2, self._num_heads, -1))
    key = npx.reshape(self.attn_key(key), (-2, -2, self._num_heads, -1))
    value = npx.reshape(self.attn_value(value), (-2, -2, self._num_heads, -1))

    out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
    out = self.attention_proj(out)
    if not self._use_bottleneck:
        out = self.dropout_layer(out)
    out = out + input
    out = self.layer_norm(out)

    for ffn_idx in range(self._num_stacked_ffn):
        ffn = self.stacked_ffn[ffn_idx]
        out = ffn(out)

    if self._use_bottleneck:
        out = self.out_bottleneck_proj(out)
        out = self.dropout_layer(out)
        out = out + data
        out = self.out_bottleneck_ln(out)
    return out, attn_weight
def transpose_for_scores(x):
    return npx.reshape(x, (-2, -2, self._num_heads, -1))