def forward(self, x, states):
    """

    Parameters
    ----------
    x
        - layout = 'NT'
            Shape (batch_size, seq_length)
        - layout = 'TN'
            Shape (seq_length, batch_size)
    states
        The previous states

        - layout = 'NT'
            Shape (num_layers, 2, batch_size, prev_len, C_in)
        - layout = 'TN'
            Shape (num_layers, 2, prev_len, batch_size, C_in)

    Returns
    -------
    new_x
        Output

        - layout = 'NT'
            Shape (batch_size, seq_length, C_out)
        - layout = 'TN'
            Shape (seq_length, batch_size, C_out)
    new_states
        The new states

        - layout = 'NT'
            Shape (num_layers, 2, batch_size, prev_len + seq_length, C_in)
        - layout = 'TN'
            Shape (num_layers, 2, prev_len + seq_length, batch_size, C_in)
    """
    prev_len = npx.shape_array(states)[3] if self._layout == 'NT' else \
               npx.shape_array(states)[2]
    x = self.get_initial_embedding(x, prev_len)
    if self._layout != self._compute_layout:
        x = np.swapaxes(x, 0, 1)
        states = np.swapaxes(states, 2, 3)
    new_states = []
    for layer_idx in range(self._num_layers):
        layer_states = None if states is None else states[layer_idx]
        x, new_layer_states = self._layers[layer_idx](x, layer_states)
        new_states.append(new_layer_states)
    new_states = np.stack(new_states, axis=0)
    x = self._final_ln(x)
    if self._layout != self._compute_layout:
        x = np.swapaxes(x, 0, 1)
        new_states = np.swapaxes(new_states, 2, 3)
    return x, new_states
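# Minimal sketch (not part of the model) of how `prev_len` is read from the cached `states`
# tensor under each layout, matching the shapes documented above. The sizes are arbitrary
# illustration values.
from mxnet import np, npx

num_layers, batch_size, prev_len, C_in = 2, 4, 3, 8
states_nt = np.zeros((num_layers, 2, batch_size, prev_len, C_in))  # layout = 'NT'
states_tn = np.zeros((num_layers, 2, prev_len, batch_size, C_in))  # layout = 'TN'
assert npx.shape_array(states_nt)[3] == prev_len  # NT: prev_len sits on axis 3
assert npx.shape_array(states_tn)[2] == prev_len  # TN: prev_len sits on axis 2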
def forward(self, x, layer_states):
    """

    Parameters
    ----------
    x
        - layout = 'NT'
            Shape (batch_size, seq_length, C_in)
        - layout = 'TN'
            Shape (seq_length, batch_size, C_in)
    layer_states
        - layout = 'NT'
            Shape (2, batch_size, prev_len, C_in)
        - layout = 'TN'
            Shape (2, prev_len, batch_size, C_in)
    """
    x = self.ln(x)
    if self._layout == 'NT':
        batch_axis, time_axis = 0, 1
        prev_len = npx.shape_array(layer_states)[2]
    else:
        batch_axis, time_axis = 1, 0
        prev_len = npx.shape_array(layer_states)[1]

    query, key, value = np.split(self.qkv(x), 3, axis=-1)
    if layer_states is not None:
        prev_key, prev_value = layer_states[0], layer_states[1]
        key = np.concatenate([prev_key, key], axis=time_axis)
        value = np.concatenate([prev_value, value], axis=time_axis)
    new_states = np.stack([key, value], axis=0)

    # gen mask
    query_pos = npx.arange_like(query, axis=time_axis)
    if prev_len is not None:
        query_pos = query_pos + prev_len
    key_pos = npx.arange_like(key, axis=time_axis)
    # (query_len, key_len)
    mask = (npx.reshape(key_pos, (1, -1)) <=
            npx.reshape(query_pos, (-1, 1))).astype(self._dtype)
    # broadcast to (batch_size, query_len, key_len)
    mask = npx.broadcast_like(np.expand_dims(mask, axis=0), query,
                              lhs_axes=0, rhs_axes=batch_axis)

    query = npx.reshape(query, (-2, -2, self._num_heads, -1))
    key = npx.reshape(key, (-2, -2, self._num_heads, -1))
    value = npx.reshape(value, (-2, -2, self._num_heads, -1))

    out, [_, attn_weight] = self.attention_cell(query, key, value, mask)
    out = self.out_proj(out)
    out = self.hidden_dropout(out)
    return out, new_states
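# Illustrative sketch (assumed toy shapes, not part of the layer) of the causal mask built in
# the "gen mask" step above: the keys cover both the cached prefix and the new tokens, while
# query positions are offset by prev_len, so each new token only attends to key positions that
# are at or before its own absolute position.
from mxnet import np, npx

batch_size, prev_len, seq_length, C_in = 2, 3, 4, 8
query = np.zeros((batch_size, seq_length, C_in))            # layout = 'NT', new tokens only
key = np.zeros((batch_size, prev_len + seq_length, C_in))   # cached prefix + new tokens

query_pos = npx.arange_like(query, axis=1) + prev_len       # absolute positions of new tokens
key_pos = npx.arange_like(key, axis=1)                      # absolute positions of all keys
mask = (npx.reshape(key_pos, (1, -1)) <= npx.reshape(query_pos, (-1, 1))).astype('float32')
# mask has shape (seq_length, prev_len + seq_length);
# mask[i, j] == 1 iff key position j is visible to query position i.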
def test_shape_array():
    A = np.zeros((INT_OVERFLOW, 2))
    A.attach_grad()
    with mx.autograd.record():
        B = npx.shape_array(A)
    assert B[0] == INT_OVERFLOW and B[1] == 2
    B.backward()
    assert A.grad.shape == (INT_OVERFLOW, 2)
    assert A.grad[0][0] == 0
def multi_head_dot_attn(query, key, value,
                        mask=None,
                        edge_scores=None,
                        dropout: float = 0.0,
                        scaled: bool = True, normalized: bool = False,
                        eps: float = 1E-6, query_head_units: Optional[int] = None,
                        layout: str = 'NKT',
                        use_einsum: bool = False,
                        dtype=np.float32):
    """Multihead dot product attention between the query, key, value.

    scaled is False, normalized is False:
        D(h_q, h_k) = <h_q, h_k>
    scaled is True, normalized is False:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    scaled is False, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    scaled is True, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    If edge_scores is provided, we will calculate the attention as
        scores = D(h_q, h_k) + EdgeScore_{q, k}

    Parameters
    ----------
    query
        Query. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, query_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, query_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads, key_dim)
    key
        Key. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, key_dim)
    value
        Value. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, value_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, value_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, value_dim)
    mask
        Mask between query and memory. Shape (batch_size, query_length, mem_length)
    edge_scores
        The edge attention score. Shape can be any shape that is broadcastable to
        (batch_size, num_heads, query_length, mem_length)
    dropout
        Dropout rate
    scaled
        Whether to divide the attention weights by the sqrt of the query dimension.
        This is first proposed in "[NIPS2017] Attention is all you need."::

            score = <h_q, h_k> / sqrt(dim_q)

    normalized
        If turned on, the cosine distance is used, i.e.::

            score = <h_q / ||h_q||, h_k / ||h_k||>

    eps
        The epsilon value used in L2 normalization
    query_head_units
        The units of each query head. If it's None, we will estimate it via the
        shape_array of the query.
    layout
        This stands for the layout of the attention cell. The shape of the input/output will
        depend on the layout. Currently, we support 'NKT', 'NTK' and 'TNK' in which
        'N' means the batch_size, 'K' means the head, and 'T' means the length dimension.
    use_einsum
        Whether to use einsum for the computation

    Returns
    -------
    context_vec
        - layout is 'NKT' or 'NTK'
            Shape (batch_size, query_length, num_heads * value_units)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads * value_units)
    additional_info
        scores:
            Shape (batch_size, num_head, query_length, mem_length)
        attn_weight:
            Shape (batch_size, num_head, query_length, mem_length)
    """
    # TODO(sxjscience) Profile layout
    if normalized:
        query = l2_normalize(query, axis=-1, eps=eps)
        key = l2_normalize(key, axis=-1, eps=eps)
    if scaled:
        if query_head_units is None:
            query_shape = npx.shape_array(query)
            scale = np.sqrt(query_shape[-1])
        else:
            scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout == 'NKT':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1)
        # 2. Calculate the attention weights
        #   Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        if scaled:
            scores = scores / scale
        attn_weights = masked_softmax(scores, mask, dtype=dtype, axis=-1)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #   (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bnjc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'NTK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1)
        # 2. Calculate the attention weights
        #   Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum('binc,bjnc->bnij', query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
                                   transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        if scaled:
            scores = scores / scale
        attn_weights = masked_softmax(scores, mask, dtype=dtype)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #   (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bjnc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'TNK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1)
        # 2. Calculate the attention weights
        #   Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        #   This layout structure can be implemented very efficiently because B, N are
        #   consecutive to each other. To have a clear picture of what's happening, we may
        #   consider the (i, j)th element of the output:
        #       out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T,
        #   which is just one GEMM call. We can thus implement the whole kernel via a single
        #   call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum('ibnc,jbnc->bnij', query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                   key.transpose((1, 2, 3, 0)))
        if edge_scores is not None:
            scores = scores + edge_scores
        if scaled:
            scores = scores / scale
        attn_weights = masked_softmax(scores, mask, dtype=dtype)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        #   (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
        #   Again, we can implement it via a single call to batched GEMM with stride.
        # Shape (B, N, L_query, C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,jbnc->ibnc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    else:
        raise NotImplementedError('layout="{}" is not supported! '
                                  'We only support layout = "NKT", "NTK", and "TNK".'
                                  .format(layout))
    return context_vec, [scores, attn_weights]
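# A minimal usage sketch of multi_head_dot_attn above with the 'NTK' layout. The sizes are
# arbitrary, and the call assumes the helpers used inside (masked_softmax, l2_normalize) are
# importable from the same module.
from mxnet import np

batch_size, num_heads, query_length, mem_length = 2, 4, 5, 7
key_dim = value_dim = 16
query = np.random.normal(size=(batch_size, query_length, num_heads, key_dim))
key = np.random.normal(size=(batch_size, mem_length, num_heads, key_dim))
value = np.random.normal(size=(batch_size, mem_length, num_heads, value_dim))
mask = np.ones((batch_size, query_length, mem_length))  # nothing is masked out

context_vec, (scores, attn_weights) = multi_head_dot_attn(query, key, value,
                                                          mask=mask, layout='NTK')
# context_vec:          (batch_size, query_length, num_heads * value_dim)
# scores, attn_weights: (batch_size, num_heads, query_length, mem_length)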
def dot_attn_score(query, key, scaled=True, normalized=False, eps=1E-6,
                   layout='NT'):
    """The inner function call to calculate the score used in dot-product attention.

    We support multiple leading batch dimensions.

    scaled is True:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)

    normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>

    both scaled and normalized:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    Parameters
    ----------
    query : symbol or ndarray
        - layout is 'NT'
            (B0, ..., BN, query_length, query_dim)
        - layout is 'TN'
            (query_length, B0, ..., BN, query_dim)
    key : symbol or ndarray
        - layout is 'NT'
            (B0, ..., BN, key_length, key_dim)
        - layout is 'TN'
            (key_length, B0, ..., BN, key_dim)
    scaled : bool
        Whether to divide the query by the square-root of the query_dim
        If True: D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    normalized : bool
        Whether to normalize the query and the key embeddings
        If True: D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    eps : float
        The epsilon used in the normalization
    layout
        The layout of the layer. Can be 'TN' or 'NT'.

    Returns
    -------
    scores : symbol or ndarray
        (B0, ..., BN, query_length, key_length)
    """
    if normalized:
        query = l2_normalize(query, -1, eps=eps)
        key = l2_normalize(key, -1, eps=eps)
    if scaled:
        query_shape = npx.shape_array(query)
        # TODO(sxjscience) Remove .astype(np.float32).
        #  Wait for https://github.com/apache/incubator-mxnet/issues/18084
        query_units = query_shape[-1].astype(np.float32)
        query = query / np.sqrt(query_units)
    if layout == 'NT':
        scores = npx.batch_dot(query, key, transpose_b=True)
    else:
        raise NotImplementedError('layout={} is not supported.'
                                  ' Currently, only layout = "NT" is implemented!'
                                  .format(layout))
    return scores
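# A minimal usage sketch of dot_attn_score above, using a single leading batch dimension and
# the default 'NT' layout. The sizes are arbitrary illustration values.
from mxnet import np

query = np.random.normal(size=(2, 5, 16))   # (batch_size, query_length, query_dim)
key = np.random.normal(size=(2, 7, 16))     # (batch_size, key_length, key_dim)
scores = dot_attn_score(query, key)         # scaled dot-product scores by default
# scores: shape (2, 5, 7) == (batch_size, query_length, key_length)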