Example #1
    def get_answerable_logits(self, contextual_embedding, p_mask):
        """Get the answerable logits.

        Parameters
        ----------
        contextual_embedding
            Shape (batch_size, sequence_length, C)
        p_mask
            Shape (batch_size, sequence_length)
            Mask of the sequence:
            0 --> the element is masked,
            1 --> the element is not masked

        Returns
        -------
        answerable_logits
            Shape (batch_size, 2)
        """
        # Shape (batch_size, sequence_length)
        start_scores = np.squeeze(self.start_scores(contextual_embedding), -1)
        start_score_weights = masked_softmax(start_scores, p_mask, axis=-1)
        start_agg_feature = npx.batch_dot(np.expand_dims(start_score_weights, axis=1),
                                          contextual_embedding)
        start_agg_feature = np.squeeze(start_agg_feature, 1)
        cls_feature = contextual_embedding[:, 0, :]
        answerable_scores = self.answerable_scores(np.concatenate([start_agg_feature,
                                                                  cls_feature], axis=-1))
        answerable_logits = npx.log_softmax(answerable_scores, axis=-1)
        return answerable_logits
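The heart of this snippet is a weighted-pooling step: the start-score weights, expanded to (batch_size, 1, sequence_length), are multiplied by the contextual embedding with npx.batch_dot, squeezed to (batch_size, C), and concatenated with the [CLS] feature. A minimal shape sketch, assuming random tensors in place of the trained start_scores layer and the masked softmax:

from mxnet import np, npx
npx.set_np()

batch_size, seq_len, C = 2, 5, 8
contextual_embedding = np.random.uniform(size=(batch_size, seq_len, C))
# Stand-in for masked_softmax over the start scores
start_score_weights = npx.softmax(np.random.uniform(size=(batch_size, seq_len)), axis=-1)

# (B, 1, T) x (B, T, C) -> (B, 1, C), then squeeze to (B, C)
start_agg_feature = npx.batch_dot(np.expand_dims(start_score_weights, axis=1),
                                  contextual_embedding)
start_agg_feature = np.squeeze(start_agg_feature, 1)
cls_feature = contextual_embedding[:, 0, :]
features = np.concatenate([start_agg_feature, cls_feature], axis=-1)
print(features.shape)  # (2, 16)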
Example #2
 def forward(self, A, B):
     # Shape of `A`/`B`: (`batch_size`, no. of words in sequence A/B,
     # `embed_size`)
     # Shape of `f_A`/`f_B`: (`batch_size`, no. of words in sequence A/B,
     # `num_hiddens`)
     f_A = self.f(A)
     f_B = self.f(B)
     # Shape of `e`: (`batch_size`, no. of words in sequence A,
     # no. of words in sequence B)
     e = npx.batch_dot(f_A, f_B, transpose_b=True)
     # Shape of `beta`: (`batch_size`, no. of words in sequence A,
     # `embed_size`), where sequence B is softly aligned with each word
     # (axis 1 of `beta`) in sequence A
     beta = npx.batch_dot(npx.softmax(e), B)
     # Shape of `alpha`: (`batch_size`, no. of words in sequence B,
     # `embed_size`), where sequence A is softly aligned with each word
     # (axis 1 of `alpha`) in sequence B
     alpha = npx.batch_dot(npx.softmax(e.transpose(0, 2, 1)), A)
     return beta, alpha
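The soft-alignment flow can be exercised on its own. A minimal sketch, assuming the MLP self.f is dropped (identity) so only the batch_dot/softmax shapes are checked:

from mxnet import np, npx
npx.set_np()

batch_size, len_A, len_B, embed_size = 2, 4, 6, 8
A = np.random.uniform(size=(batch_size, len_A, embed_size))
B = np.random.uniform(size=(batch_size, len_B, embed_size))

e = npx.batch_dot(A, B, transpose_b=True)                     # (2, 4, 6)
beta = npx.batch_dot(npx.softmax(e), B)                       # (2, 4, 8): B aligned to A
alpha = npx.batch_dot(npx.softmax(e.transpose(0, 2, 1)), A)   # (2, 6, 8): A aligned to B
print(e.shape, beta.shape, alpha.shape)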
Example #3
 def forward(self, queries, keys, values, valid_lens):
     queries, keys = self.W_q(queries), self.W_k(keys)
     # After dimension expansion, shape of `queries`: (`batch_size`, no. of
     # queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1,
     # no. of key-value pairs, `num_hiddens`). Sum them up with
     # broadcasting
     features = np.expand_dims(queries, axis=2) + np.expand_dims(
         keys, axis=1)
     features = np.tanh(features)
     # There is only one output of `self.w_v`, so we remove the last
     # one-dimensional entry from the shape. Shape of `scores`:
     # (`batch_size`, no. of queries, no. of key-value pairs)
     scores = np.squeeze(self.w_v(features), axis=-1)
     self.attention_weights = masked_softmax(scores, valid_lens)
     # Shape of `values`: (`batch_size`, no. of key-value pairs, value
     # dimension)
     return npx.batch_dot(self.dropout(self.attention_weights), values)
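The broadcasting step is the key trick here: expanding queries to (batch_size, no. of queries, 1, num_hiddens) and keys to (batch_size, 1, no. of key-value pairs, num_hiddens) lets one addition cover every query-key pair. A minimal sketch, assuming the W_q/W_k projections are skipped and the inputs already have num_hiddens features:

from mxnet import np, npx
npx.set_np()

batch_size, num_queries, num_kv, num_hiddens = 2, 3, 5, 8
queries = np.random.uniform(size=(batch_size, num_queries, num_hiddens))
keys = np.random.uniform(size=(batch_size, num_kv, num_hiddens))

# (B, Q, 1, H) + (B, 1, K, H) -> (B, Q, K, H) via broadcasting
features = np.expand_dims(queries, axis=2) + np.expand_dims(keys, axis=1)
print(features.shape)  # (2, 3, 5, 8)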
Example #4
 def forward(self, query, key, value, valid_len=None):
     d = query.shape[-1]
     # Set transpose_b=True to swap the last two dimensions of key
     scores = npx.batch_dot(query, key, transpose_b=True) / math.sqrt(d)
     attention_weights = self.dropout(masked_softmax(scores, valid_len))
     return npx.batch_dot(attention_weights, value)
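A quick check, on random tensors, that transpose_b=True matches swapping the last two axes of key by hand; standard NumPy is used only for the comparison:

from mxnet import np, npx
import numpy as onp
npx.set_np()

query = np.random.uniform(size=(2, 3, 4))
key = np.random.uniform(size=(2, 5, 4))

s1 = npx.batch_dot(query, key, transpose_b=True)
s2 = npx.batch_dot(query, np.swapaxes(key, 1, 2))
print(s1.shape, onp.allclose(s1.asnumpy(), s2.asnumpy()))  # (2, 3, 5) True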
Example #5
def multi_head_dot_attn(query, key, value,
                        mask=None,
                        edge_scores=None,
                        dropout: float = 0.0,
                        scaled: bool = True, normalized: bool = False,
                        eps: float = 1E-6, query_head_units: Optional[int] = None,
                        layout: str = 'NKT',
                        use_einsum: bool = False):
    """Multihead dot product attention between the query, key, value.

    scaled is False, normalized is False:
        D(h_q, h_k) = <h_q, h_k>
    scaled is True, normalized is False:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    scaled is False, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    scaled is True, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    If edge_scores is provided, we will calculate the attention as
        scores = D(h_q, h_k) + EdgeScore_{q, k}

    Parameters
    ----------
    query
        Query. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, query_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, query_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads, key_dim)

    key
        Key. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, key_dim)

    value
        Value. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, value_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, value_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, value_dim)

    mask
        Mask between query and memory. Shape (batch_size, query_length, mem_length)
    edge_scores
        The edge attention score. Can be any shape that is broadcastable to
        (batch_size, num_heads, query_length, mem_length)
    dropout
        Dropout rate
    scaled
        Whether to divide the attention weights by the sqrt of the query dimension.
        This was first proposed in "[NIPS2017] Attention is all you need":

        .. code-block:: none

            score = <h_q, h_k> / sqrt(dim_q)

    normalized
        If turned on, the cosine similarity is used, i.e.:

        .. code-block:: none

            score = <h_q / ||h_q||, h_k / ||h_k||>

    eps
        The epsilon value used in L2 normalization
    query_head_units
        The number of units in each query head. When `scaled` is True, it must be
        specified, otherwise a NotImplementedError is raised.
    layout
        This stands for the layout of the attention cell. The shape of the input/output will depend
        on the layout. Currently, we support 'NKT', 'NTK' and 'TNK' in which
        'N' means the batch_size, 'K' means the head, and 'T' means the length dimension.
    use_einsum
        Whether to use einsum for the computation

    Returns
    -------
    context_vec
        - layout is 'NKT' or 'NTK'
            Shape (batch_size, query_length, num_heads * value_units)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads * value_units)

    additional_info
        scores:
            Shape (batch_size, num_head, query_length, mem_length)
        attn_weight:
            Shape (batch_size, num_head, query_length, mem_length)
    """
    # TODO(sxjscience) Profile layout
    if normalized:
        query = l2_normalize(query, axis=-1, eps=eps)
        key = l2_normalize(key, axis=-1, eps=eps)
    if scaled:
        if query_head_units is None:
            raise NotImplementedError('You will need to specify query_head_units!')
        else:
            scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout == 'NKT':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bnjc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'NTK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum('binc,bjnc->bnij', query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
                                   transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bjnc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'TNK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        #   This layout structure can be implemented very efficiently because B, N are consecutive
        #   to each other. To have a clear picture of what's happening, we may consider the
        #   (i, j)th element of the output
        #       out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T, which is just one GEMM call
        #   We can thus implement the whole kernel via a single call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum('ibnc,jbnc->bnij', query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                   key.transpose((1, 2, 3, 0)))
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
        # Again, we can implement it via a single call to batched GEMM with stride.

        # Shape (B, N, L_query, C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,jbnc->ibnc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    else:
        raise NotImplementedError('layout="{}" is not supported! '
                                  'We only support layout = "NKT", "NTK", and "TNK".'
                                  .format(layout))
    return context_vec, [scores, attn_weights]
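The einsum and batch_dot branches are intended to produce the same scores. A small consistency check for the 'NTK' score computation, on random tensors, with standard NumPy used only for the comparison:

from mxnet import np, npx
import numpy as onp
npx.set_np()

B, N, L_q, L_m, C = 2, 4, 3, 5, 8
query = np.random.uniform(size=(B, L_q, N, C))   # 'NTK' layout
key = np.random.uniform(size=(B, L_m, N, C))

scores_einsum = np.einsum('binc,bjnc->bnij', query, key)
scores_batch_dot = npx.batch_dot(np.swapaxes(query, 1, 2),
                                 np.swapaxes(key, 1, 2), transpose_b=True)
print(scores_einsum.shape)  # (2, 4, 3, 5)
print(onp.allclose(scores_einsum.asnumpy(), scores_batch_dot.asnumpy(), atol=1e-5))  # True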
Example #6
 def forward(self, query, key, value, valid_len=None):
     d = query.shape[-1]  # dimension
     # Set transpose_b=True to swap the last two dimensions of key
     # Why sqrt(d)? See http://www.peterbloem.nl/blog/transformers: a vector in
     # R^d with entries all equal to c has Euclidean length sqrt(d) * c, so we
     # divide out the extra length that comes purely from the dimension.
     scores = npx.batch_dot(query, key, transpose_b=True) / math.sqrt(d)
     attention_weights = self.dropout(masked_softmax(scores, valid_len))
     return npx.batch_dot(attention_weights, value)
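A quick numeric illustration of the comment above: a vector in R^k whose entries all equal c has Euclidean length sqrt(k) * c, so unscaled dot products grow with the dimension, and dividing by sqrt(d) removes that growth:

import math
from mxnet import np, npx
npx.set_np()

for k in (4, 16, 64):
    v = np.ones((k,)) * 0.5                 # all entries equal to c = 0.5
    length = float(np.sqrt(np.sum(v * v)))  # Euclidean length of v
    print(k, length, math.sqrt(k) * 0.5)    # the last two values coincide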
Example #7
 def forward(self, queries, keys, values, valid_lens=None):
     d = queries.shape[-1]
     # Set `transpose_b=True` to swap the last two dimensions of `keys`
     scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
     self.attention_weights = masked_softmax(scores, valid_lens)
     return npx.batch_dot(self.dropout(self.attention_weights), values)
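masked_softmax itself is not defined in these snippets. A minimal version consistent with how it is called here, roughly following the d2l.ai reference helper (valid_lens given per example as a 1-D tensor or per query row as a 2-D tensor):

from mxnet import np, npx
npx.set_np()

def masked_softmax(X, valid_lens):
    """Softmax over the last axis, ignoring entries beyond `valid_lens`."""
    if valid_lens is None:
        return npx.softmax(X)
    shape = X.shape
    if valid_lens.ndim == 1:
        valid_lens = valid_lens.repeat(shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # Replace masked-out entries with a very large negative value so that
    # their softmax output is effectively zero
    X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
                          value=-1e6, axis=1)
    return npx.softmax(X).reshape(shape)

# Keep only the first 2 and 3 key-value scores of the two examples, respectively
print(masked_softmax(np.random.uniform(size=(2, 2, 4)), np.array([2, 3])))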
Example #8
def dot_attn_score(query,
                   key,
                   scaled=True,
                   normalized=False,
                   eps=1E-6,
                   layout='NT'):
    """The inner function call to calculate the score used in dot-product attention.

    We support multiple leading batch dimensions.

    scaled is True:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)

    normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>

    both scaled and normalized:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    Parameters
    ----------
    query : symbol or ndarray
        - layout is 'NT'
            (B0, ..., BN, query_length, query_dim)
        - layout is 'TN'
            (query_length, B0, ..., BN, query_dim)
    key : symbol or ndarray
        - layout is 'NT'
            (B0, ..., BN, key_length, key_dim)
        - layout is 'TN'
            (key_length, B0, ..., BN, key_dim)
    scaled : bool
        Whether to divide the query by the square-root of the query_dim
        If True: D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    normalized : bool
        Whether to normalize the query and the key embeddings
        If True: D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    eps : float
        The epsilon used in the normalization
    layout
        The layout of the layer. Can be 'TN' or 'NT'.

    Returns
    -------
    scores : symbol or ndarray
        (B0, ..., BN, query_length, key_length)
    """
    if normalized:
        query = l2_normalize(query, -1, eps=eps)
        key = l2_normalize(key, -1, eps=eps)
    if scaled:
        query_shape = npx.shape_array(query)
        # TODO(sxjscience) Remove .astype(np.float32).
        #  Wait for https://github.com/apache/incubator-mxnet/issues/18084
        query_units = query_shape[-1].astype(np.float32)
        query = query / np.sqrt(query_units)
    if layout == 'NT':
        scores = npx.batch_dot(query, key, transpose_b=True)
    else:
        raise NotImplementedError(
            'layout={} is not supported.'
            ' Currently, only layout = "NT" is implemented!'.format(layout))
    return scores
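A small sketch of the 'NT' path with normalized=True, assuming l2_normalize (not shown here) simply divides by the L2 norm along the last axis with eps for stability; the resulting scores are cosine similarities:

from mxnet import np, npx
npx.set_np()

def l2_normalize_sketch(x, axis=-1, eps=1e-6):
    # Divide by the L2 norm along `axis`, guarding against division by zero
    norm = np.sqrt(np.sum(x * x, axis=axis, keepdims=True))
    return x / (norm + eps)

query = np.random.uniform(size=(2, 3, 8))   # (B0, query_length, query_dim)
key = np.random.uniform(size=(2, 5, 8))     # (B0, key_length, key_dim)

scores = npx.batch_dot(l2_normalize_sketch(query),
                       l2_normalize_sketch(key), transpose_b=True)
print(scores.shape)  # (2, 3, 5); each entry is a cosine similarity in [-1, 1]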