def get_end_logits(self, contextual_embedding, start_positions, p_mask):
    """Score every token as a candidate end position for each start candidate.

    Parameters
    ----------
    contextual_embedding
        Shape (batch_size, sequence_length, C)
    start_positions
        Shape (batch_size, N). N start candidates are scored simultaneously.
    p_mask
        Shape (batch_size, sequence_length)

    Returns
    -------
    end_logits
        Shape (batch_size, N, sequence_length)
    """
    # Gather the contextual features at each candidate start position.
    # (batch_size, N, C)
    start_feats = select_vectors_by_position(contextual_embedding, start_positions)
    # Pair every start feature with every token feature: broadcast both
    # tensors to (B, N, T, C) and concatenate along the channel axis.
    ctx = np.expand_dims(contextual_embedding, axis=1)   # (B, 1, T, C)
    start_feats = np.expand_dims(start_feats, axis=2)    # (B, N, 1, C)
    pair_feats = np.concatenate(
        [npx.broadcast_like(start_feats, ctx, 2, 2),
         npx.broadcast_like(ctx, start_feats, 1, 1)],
        axis=-1)                                         # (B, N, T, 2C)
    # Per-(start, token) scalar score, then squeeze the channel axis.
    scores = np.squeeze(self.end_scores(pair_feats), -1)
    # Mask out padding/invalid tokens before the log-softmax over tokens.
    return masked_logsoftmax(scores, mask=np.expand_dims(p_mask, axis=1), axis=-1)
def test_broadcast_like():
    """broadcast_like must work when the broadcast axis exceeds int32 range."""
    small = np.ones((1, 2))
    big = np.zeros((INT_OVERFLOW, 2))
    small.attach_grad()
    # Broadcast along axis 0; check forward value and gradient shape.
    with mx.autograd.record():
        out = npx.broadcast_like(small, big)
    assert out.shape == (INT_OVERFLOW, 2)
    assert out[0][0] == 1
    out.backward()
    assert small.grad.shape == (1, 2)
    # Same check with the axes swapped via reshape / transpose.
    with mx.autograd.record():
        out = npx.broadcast_like(small.reshape(2, 1), big.T)
    assert out.shape == (2, INT_OVERFLOW)
    assert out[0][0] == 1
    out.backward()
    assert small.grad.shape == (1, 2)
    # The gradient of a broadcast is a sum over the broadcast axis,
    # so each input element accumulates INT_OVERFLOW ones.
    assert_almost_equal(small.grad[0][0], np.array([INT_OVERFLOW]),
                        rtol=1e-3, atol=1e-5)
def forward(self, x, layer_states):
    """Self-attention step with an incremental key/value cache.

    Parameters
    ----------
    x
        - layout = 'NT'
            Shape (batch_size, seq_length, C_in)
        - layout = 'TN'
            Shape (seq_length, batch_size, C_in)
    layer_states
        Cached keys/values from previous steps, or None when no cache exists.
        - layout = 'NT'
            Shape (2, batch_size, prev_len, C_in)
        - layout = 'TN'
            Shape (2, prev_len, batch_size, C_in)

    Returns
    -------
    out
        Attention output, same layout as ``x``.
    new_states
        Updated (key, value) cache stacked on axis 0.
    """
    x = self.ln(x)
    if self._layout == 'NT':
        batch_axis, time_axis = 0, 1
    else:
        batch_axis, time_axis = 1, 0
    # BUGFIX: prev_len was previously computed by indexing layer_states
    # *before* the None check, which crashed whenever no cache was passed
    # and made the later `if prev_len is not None` guard dead code.
    # Only query the cache length when a cache actually exists.
    if layer_states is not None:
        # time axis of the states tensor is shifted by the leading
        # (key/value) axis: index 2 for 'NT', index 1 for 'TN'.
        prev_len = npx.shape_array(layer_states)[time_axis + 1]
    else:
        prev_len = None
    query, key, value = np.split(self.qkv(x), 3, axis=-1)
    if layer_states is not None:
        prev_key, prev_value = layer_states[0], layer_states[1]
        key = np.concatenate([prev_key, key], axis=time_axis)
        value = np.concatenate([prev_value, value], axis=time_axis)
    new_states = np.stack([key, value], axis=0)

    # Build the causal mask: a query at absolute position q may attend to
    # every key at absolute position k <= q. Query positions are offset by
    # the cached prefix length when a cache is present.
    query_pos = npx.arange_like(query, axis=time_axis)
    if prev_len is not None:
        query_pos = query_pos + prev_len
    key_pos = npx.arange_like(key, axis=time_axis)
    # (query_len, key_len)
    mask = (npx.reshape(key_pos, (1, -1)) <= npx.reshape(query_pos, (-1, 1))).astype(
        self._dtype)
    # broadcast to (batch_size, query_len, key_len)
    mask = npx.broadcast_like(np.expand_dims(mask, axis=0), query,
                              lhs_axes=0, rhs_axes=batch_axis)

    # Split the channel axis into (num_heads, head_dim) for the attention cell.
    query = npx.reshape(query, (-2, -2, self._num_heads, -1))
    key = npx.reshape(key, (-2, -2, self._num_heads, -1))
    value = npx.reshape(value, (-2, -2, self._num_heads, -1))
    out, [_, attn_weight] = self.attention_cell(query, key, value, mask)
    out = self.out_proj(out)
    out = self.hidden_dropout(out)
    return out, new_states
def prepare_source_valid_lengths(valid_length: np.ndarray,
                                 query_data: np.ndarray,
                                 num_heads: int) -> np.ndarray:
    """
    Expand per-sequence valid lengths for multi-head attention.

    Produces an int32 tensor of shape (batch * num_heads, query_length) to be
    used in the softmax operation in DotAttentionCell with the length argument.
    Due to broadcast_like, dtypes of valid_length and query_data must be the same.

    :param valid_length: Valid length information. Shape: (batch,).
    :param query_data: Tensor from which the query_length dimension is derived.
                       Expected shape: (X, query_length, ...).
    :param num_heads: Number of attention heads.
    :return: int32 tensor of shape (batch * num_heads, query_length).
    """
    # Duplicate each batch entry once per head: (batch,) -> (batch * heads,)
    per_head = np.repeat(valid_length, repeats=num_heads, axis=0)
    # Tile along a new axis whose size is taken from query_data's axis 1
    # (the query_length dimension): (batch * heads, 1) -> (batch * heads, query_length)
    expanded = npx.broadcast_like(np.expand_dims(per_head, axis=1),
                                  query_data,
                                  lhs_axes=(1, ),
                                  rhs_axes=(1, ))
    return expanded.astype(dtype='int32', copy=False)