Example #1
    def get_end_logits(self, contextual_embedding, start_positions, p_mask):
        """

        Parameters
        ----------
        contextual_embedding
            Shape (batch_size, sequence_length, C)
        start_positions
            Shape (batch_size, N)
            We process multiple candidates simultaneously
        p_mask
            Shape (batch_size, sequence_length)

        Returns
        -------
        end_logits
            Shape (batch_size, N, sequence_length)
        """
        # Select the features at the start_positions
        # start_features will have shape (batch_size, N, C)
        start_features = select_vectors_by_position(contextual_embedding, start_positions)
        # Concatenate the start_features and the contextual_embedding
        contextual_embedding = np.expand_dims(contextual_embedding, axis=1)  # (B, 1, T, C)
        start_features = np.expand_dims(start_features, axis=2)  # (B, N, 1, C)
        concat_features = np.concatenate(
            [npx.broadcast_like(start_features, contextual_embedding, 2, 2),
             npx.broadcast_like(contextual_embedding, start_features, 1, 1)],
            axis=-1)  # (B, N, T, 2C)
        end_scores = self.end_scores(concat_features)
        end_scores = np.squeeze(end_scores, -1)
        end_logits = masked_logsoftmax(end_scores, mask=np.expand_dims(p_mask, axis=1),
                                       axis=-1)
        return end_logits
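The pairwise broadcast above can be checked in isolation. Here is a minimal, self-contained sketch (not part of the original model; toy shapes B=2, N=3, T=5, C=4 are illustrative) using the same positional lhs_axes/rhs_axes arguments:

from mxnet import np, npx
npx.set_np()

B, N, T, C = 2, 3, 5, 4
start_features = np.ones((B, N, 1, C))
contextual = np.zeros((B, 1, T, C))

# Broadcast axis 2 of start_features (size 1) up to T, and axis 1 of
# contextual (size 1) up to N, so both operands become (B, N, T, C).
lhs = npx.broadcast_like(start_features, contextual, 2, 2)
rhs = npx.broadcast_like(contextual, start_features, 1, 1)
pair = np.concatenate([lhs, rhs], axis=-1)  # (B, N, T, 2C)
assert pair.shape == (B, N, T, 2 * C)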
Example #2
def test_broadcast_like():
    A = np.ones((1, 2))
    B = np.zeros((INT_OVERFLOW, 2))
    A.attach_grad()
    with mx.autograd.record():
        C = npx.broadcast_like(A, B)
    assert C.shape == (INT_OVERFLOW, 2)
    assert C[0][0] == 1
    C.backward()
    assert A.grad.shape == (1, 2)
    with mx.autograd.record():
        C = npx.broadcast_like(A.reshape(2, 1), B.T)
    assert C.shape == (2, INT_OVERFLOW)
    assert C[0][0] == 1
    C.backward()
    assert A.grad.shape == (1, 2)
    assert_almost_equal(A.grad[0][0], np.array([INT_OVERFLOW]),
                        rtol=1e-3, atol=1e-5)
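INT_OVERFLOW above comes from MXNet's large-tensor test suite, but the gradient behavior is the same at small scale: backward through broadcast_like sums the head gradient over each broadcast axis. A scaled-down sketch (shapes are illustrative):

import mxnet as mx
from mxnet import np, npx
npx.set_np()

A = np.ones((1, 2))
B = np.zeros((4, 2))
A.attach_grad()
with mx.autograd.record():
    C = npx.broadcast_like(A, B)  # (1, 2) broadcast to (4, 2)
C.backward()  # implicit head gradient of ones
assert A.grad.shape == (1, 2)
assert A.grad[0][0] == 4  # summed over the 4 broadcast rows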
Example #3
    def forward(self, x, layer_states):
        """

        Parameters
        ----------
        x
            - layout = 'NT'
                Shape (batch_size, seq_length, C_in)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_in)

        layer_states
            - layout = 'NT'
                Shape (2, batch_size, prev_len, C_in)
            - layout = 'TN'
                Shape (2, prev_len, batch_size, C_in)
        """
        x = self.ln(x)
        if self._layout == 'NT':
            batch_axis, time_axis = 0, 1
            prev_len = npx.shape_array(layer_states)[2]
        else:
            batch_axis, time_axis = 1, 0
            prev_len = npx.shape_array(layer_states)[1]

        query, key, value = np.split(self.qkv(x), 3, axis=-1)
        if layer_states is not None:
            prev_key, prev_value = layer_states[0], layer_states[1]
            key = np.concatenate([prev_key, key], axis=time_axis)
            value = np.concatenate([prev_value, value], axis=time_axis)
        new_states = np.stack([key, value], axis=0)

        # gen mask
        query_pos = npx.arange_like(query, axis=time_axis)
        if prev_len is not None:
            query_pos = query_pos + prev_len
        key_pos = npx.arange_like(key, axis=time_axis)
        # (query_len, key_len)
        mask = (npx.reshape(key_pos, (1, -1))
                <= npx.reshape(query_pos, (-1, 1))).astype(self._dtype)
        # broadcast to (batch_size, query_len, key_len)
        mask = npx.broadcast_like(np.expand_dims(mask, axis=0),
                                  query,
                                  lhs_axes=0,
                                  rhs_axes=batch_axis)

        query = npx.reshape(query, (-2, -2, self._num_heads, -1))
        key = npx.reshape(key, (-2, -2, self._num_heads, -1))
        value = npx.reshape(value, (-2, -2, self._num_heads, -1))

        out, [_, attn_weight] = self.attention_cell(query, key, value, mask)
        out = self.out_proj(out)
        out = self.hidden_dropout(out)

        return out, new_states
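The mask construction above chains arange_like, a comparison, and broadcast_like. Below is a minimal sketch of just that step, assuming the 'NT' layout (batch axis 0); the toy shapes and names are illustrative, not from the original layer:

from mxnet import np, npx
npx.set_np()

batch_size, query_len, key_len = 2, 3, 5
query = np.zeros((batch_size, query_len, 16))
prev_len = key_len - query_len  # length of the cached prefix

query_pos = npx.arange_like(query, axis=1) + prev_len  # (query_len,)
key_pos = np.arange(key_len)                           # (key_len,)
# Causal mask: query at position i may attend to keys at positions <= i.
mask = (npx.reshape(key_pos, (1, -1))
        <= npx.reshape(query_pos, (-1, 1))).astype('float32')
# Tile (1, query_len, key_len) along the batch axis of query.
mask = npx.broadcast_like(np.expand_dims(mask, axis=0), query,
                          lhs_axes=0, rhs_axes=0)
assert mask.shape == (batch_size, query_len, key_len)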
Example #4
def prepare_source_valid_lengths(valid_length: np.ndarray,
                                 query_data: np.ndarray,
                                 num_heads: int) -> np.ndarray:
    """
    Returns an int32 valid length tensor of shape (batch * num_heads, query_length) to be used in
    the softmax operation in DotAttentionCell with the length argument.
    Due to broadcast_like, dtypes of valid_length and query_data must be the same.

    :param valid_length: Valid length information. Shape: (batch,).
    :param query_data: Tensor from which the query_length dimension is derived.
                       Expected shape: (X, query_length, ...).
    :param num_heads: Number of attention heads.
    :return: int32 tensor of shape (batch * num_heads, query_length).
    """
    # (batch * heads,)
    att_valid_length = np.repeat(valid_length, repeats=num_heads, axis=0)
    att_valid_length = npx.broadcast_like(np.expand_dims(att_valid_length,
                                                         axis=1),
                                          query_data,
                                          lhs_axes=(1, ),
                                          rhs_axes=(1, ))
    return att_valid_length.astype(dtype='int32', copy=False)
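A toy run of this helper's two-step pattern, repeating lengths across heads and then broadcasting the single length column across query_length; the shapes and values below are illustrative, not from Sockeye:

from mxnet import np, npx
npx.set_np()

batch, heads, query_len = 2, 4, 5
valid_length = np.array([3, 5])                       # (batch,)
query_data = np.zeros((batch * heads, query_len, 8))  # (X, query_length, ...)

att = np.repeat(valid_length, repeats=heads, axis=0)    # (batch * heads,)
att = npx.broadcast_like(np.expand_dims(att, axis=1),   # (batch * heads, 1)
                         query_data,
                         lhs_axes=(1,), rhs_axes=(1,))  # (batch * heads, query_len)
att = att.astype('int32', copy=False)
assert att.shape == (batch * heads, query_len)
assert att[0][0] == 3  # first sentence's length, tiled across query_len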