Example No. 1
    def forward(self, x, states):
        """

        Parameters
        ----------
        x
            - layout = 'NT'
                Shape (batch_size, seq_length)
            - layout = 'TN'
                Shape (seq_length, batch_size)

        states
            The previous states

            - layout = 'NT'
                Shape (num_layers, 2, batch_size, prev_len, C_in)
            - layout = 'TN'
                Shape (num_layers, 2, prev_len, batch_size, C_in)

        Returns
        -------
        new_x
            Output

            - layout = 'NT'
                Shape (batch_size, seq_length, C_out)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_out)

        new_states
            The new states

            - layout = 'NT'
                Shape (num_layers, 2, batch_size, prev_len + seq_length, C_in)
            - layout = 'TN'
                Shape (num_layers, 2, prev_len + seq_length, batch_size, C_in)

        """
        prev_len = npx.shape_array(states)[3] if self._layout == 'NT' else \
                   npx.shape_array(states)[2]
        x = self.get_initial_embedding(x, prev_len)

        if self._layout != self._compute_layout:
            x = np.swapaxes(x, 0, 1)
            states = np.swapaxes(states, 2, 3)

        new_states = []
        for layer_idx in range(self._num_layers):
            layer_states = None if states is None else states[layer_idx]
            x, new_layer_states = self._layers[layer_idx](x, layer_states)
            new_states.append(new_layer_states)
        new_states = np.stack(new_states, axis=0)

        x = self._final_ln(x)
        if self._layout != self._compute_layout:
            x = np.swapaxes(x, 0, 1)
            new_states = np.swapaxes(new_states, 2, 3)
        return x, new_states
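As a quick illustration of the layout handling in this forward pass, the sketch below uses plain NumPy (not mxnet.numpy) with made-up sizes to show how the batch/time axes of x and the stacked states are swapped when _layout differs from _compute_layout; all names and sizes here are illustrative assumptions.

# Minimal plain-NumPy sketch of the layout bookkeeping in forward() above;
# the sizes are illustrative assumptions, not taken from the class itself.
import numpy as np

batch_size, seq_length, prev_len, num_layers, C_in = 2, 4, 3, 6, 8

# 'NT' layout: embeddings are (batch_size, seq_length, C_in) and the stacked
# states are (num_layers, 2, batch_size, prev_len, C_in).
x_nt = np.zeros((batch_size, seq_length, C_in))
states_nt = np.zeros((num_layers, 2, batch_size, prev_len, C_in))

# Switching to the 'TN' compute layout swaps the batch/time axes of x and
# axes 2/3 of the states, exactly as in the two np.swapaxes calls above.
x_tn = np.swapaxes(x_nt, 0, 1)
states_tn = np.swapaxes(states_nt, 2, 3)

assert x_tn.shape == (seq_length, batch_size, C_in)
assert states_tn.shape == (num_layers, 2, prev_len, batch_size, C_in)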
Example No. 2
    def forward(self, x, layer_states):
        """

        Parameters
        ----------
        x
            - layout = 'NT'
                Shape (batch_size, seq_length, C_in)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_in)

        layer_states
            - layout = 'NT'
                Shape (2, batch_size, prev_len, C_in)
            - layout = 'TN'
                Shape (2, prev_len, batch_size, C_in)
        """
        x = self.ln(x)
        if self._layout == 'NT':
            batch_axis, time_axis = 0, 1
            prev_len = npx.shape_array(layer_states)[2]
        else:
            batch_axis, time_axis = 1, 0
            prev_len = npx.shape_array(layer_states)[1]

        query, key, value = np.split(self.qkv(x), 3, axis=-1)
        if layer_states is not None:
            prev_key, prev_value = layer_states[0], layer_states[1]
            key = np.concatenate([prev_key, key], axis=time_axis)
            value = np.concatenate([prev_value, value], axis=time_axis)
        new_states = np.stack([key, value], axis=0)

        # Generate the causal attention mask: each new query position may attend
        # to every cached position and to new positions up to and including itself.
        query_pos = npx.arange_like(query, axis=time_axis)
        if prev_len is not None:
            query_pos = query_pos + prev_len
        key_pos = npx.arange_like(key, axis=time_axis)
        # (query_len, key_len)
        mask = (npx.reshape(key_pos,
                            (1, -1)) <= npx.reshape(query_pos,
                                                    (-1, 1))).astype(
                                                        self._dtype)
        # broadcast to (batch_size, query_len, key_len)
        mask = npx.broadcast_like(np.expand_dims(mask, axis=0),
                                  query,
                                  lhs_axes=0,
                                  rhs_axes=batch_axis)

        query = npx.reshape(query, (-2, -2, self._num_heads, -1))
        key = npx.reshape(key, (-2, -2, self._num_heads, -1))
        value = npx.reshape(value, (-2, -2, self._num_heads, -1))

        out, [_, attn_weight] = self.attention_cell(query, key, value, mask)
        out = self.out_proj(out)
        out = self.hidden_dropout(out)

        return out, new_states
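The mask construction above admits a compact illustration. The sketch below uses plain NumPy with made-up lengths to show the causal mask for incremental decoding: with prev_len cached positions, the i-th new query position may attend to key positions 0 through prev_len + i.

# Plain-NumPy sketch of the causal mask built in the layer above; prev_len and
# seq_length are illustrative assumptions.
import numpy as np

prev_len, seq_length = 3, 4
query_pos = np.arange(seq_length) + prev_len   # positions of the new tokens
key_pos = np.arange(prev_len + seq_length)     # cached + new key positions
mask = (key_pos.reshape(1, -1) <= query_pos.reshape(-1, 1)).astype(np.float32)
print(mask)
# [[1. 1. 1. 1. 0. 0. 0.]
#  [1. 1. 1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1. 1. 1.]]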
Example No. 3
def test_shape_array():
    A = np.zeros((INT_OVERFLOW, 2))
    A.attach_grad()
    with mx.autograd.record():
        B = npx.shape_array(A)
    assert B[0] == INT_OVERFLOW and B[1] == 2
    B.backward()
    assert A.grad.shape == (INT_OVERFLOW, 2)
    assert A.grad[0][0] == 0
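For reference, npx.shape_array returns the shape of an ndarray as an ndarray, which is what the test checks at a tensor size beyond the 32-bit index range (INT_OVERFLOW is a large constant defined elsewhere in the test suite and left as-is here). A minimal sketch on an ordinary-sized array, assuming an MXNet build with the NumPy interface enabled:

# Minimal sketch of npx.shape_array on a small array; assumes the MXNet
# NumPy interface is available and enabled.
from mxnet import np, npx
npx.set_np()

A = np.zeros((5, 2))
B = npx.shape_array(A)   # the shape of A, returned as an ndarray: [5, 2]
assert B[0] == 5 and B[1] == 2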
Example No. 4
def multi_head_dot_attn(query,
                        key,
                        value,
                        mask=None,
                        edge_scores=None,
                        dropout: float = 0.0,
                        scaled: bool = True,
                        normalized: bool = False,
                        eps: float = 1E-6,
                        query_head_units: Optional[int] = None,
                        layout: str = 'NKT',
                        use_einsum: bool = False,
                        dtype=np.float32):
    """Multihead dot product attention between the query, key, value.

    scaled is False, normalized is False:
        D(h_q, h_k) = <h_q, h_k>
    scaled is True, normalized is False:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    scaled is False, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    scaled is True, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    If edge_scores is provided, we will calculate the attention scores as
        scores = D(h_q, h_k) + EdgeScore_{q, k}

    Parameters
    ----------
    query
        Query. The shape depends on the layout
        - layout is 'NKT'
            Shape (batch_size, num_heads, query_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, query_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads, key_dim)
    key
        Key. The shape depends on the layout
        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, key_dim)
    value
        Value. The shape depends on the layout
        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, value_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, value_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, value_dim)
    mask
        Mask between query and memory. Shape (batch_size, query_length, mem_length)
    edge_scores
        The edge attention score. Shape can be any shape that is broadcastable to
        (batch_size, num_heads, query_length, mem_length)
    dropout
        Dropout rate
    scaled
        Whether to divide the attention scores by the square root of the query dimension.
        This was first proposed in "[NIPS2017] Attention is all you need."::

            score = <h_q, h_k> / sqrt(dim_q)

    normalized
        If turned on, cosine similarity is used, i.e.::

            score = <h_q / ||h_q||, h_k / ||h_k||>

    eps
        The epsilon value used in L2 normalization
    query_head_units
        The number of units in each query head. If it is None, we will infer it from the
        shape_array of the query.
    layout
        This stands for the layout of the attention cell. The shape of the input/output will depend
        on the layout. Currently, we support 'NKT', 'NTK' and 'TNK' in which
        'N' means the batch_size, 'K' means the head, and 'T' means the length dimension.
    use_einsum
        Whether to use einsum for the computation
    dtype
        Data type of the attention weights computed in the masked softmax

    Returns
    -------
    context_vec
        - layout is 'NKT' or 'NTK'
            Shape (batch_size, query_length, num_heads * value_units)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads * value_units)
    additional_info
        scores:
            Shape (batch_size, num_heads, query_length, mem_length)
        attn_weights:
            Shape (batch_size, num_heads, query_length, mem_length)
    """
    # TODO(sxjscience) Profile layout
    if normalized:
        query = l2_normalize(query, axis=-1, eps=eps)
        key = l2_normalize(key, axis=-1, eps=eps)
    if scaled:
        if query_head_units is None:
            query_shape = npx.shape_array(query)
            scale = np.sqrt(query_shape[-1])
        else:
            scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout == 'NKT':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1)
        # 2. Calculate the attention weights
        #   Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        if scaled:
            scores = scores / scale
        attn_weights = masked_softmax(scores, mask, dtype=dtype, axis=-1)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bnjc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose(
                (0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'NTK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1)
        # 2. Calculate the attention weights
        #   Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum('binc,bjnc->bnij', query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2),
                                   np.swapaxes(key, 1, 2),
                                   transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        if scaled:
            scores = scores / scale
        attn_weights = masked_softmax(scores, mask, dtype=dtype)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bjnc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        np.swapaxes(value, 1, 2)).transpose(
                                            (0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'TNK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1)
        # 2. Calculate the attention weights
        #   Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        #   This layout structure can be implemented very efficiently because B, N are consecutive
        #   to each other. To have a clear picture of what's happening, we may consider the
        #   (i, j)th element of the output
        #       out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T, which is just one GEMM call
        #   We can thus implement the whole kernel via a single call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum('ibnc,jbnc->bnij', query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                   key.transpose((1, 2, 3, 0)))
        if edge_scores is not None:
            scores = scores + edge_scores
        if scaled:
            scores = scores / scale
        attn_weights = masked_softmax(scores, mask, dtype=dtype)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
        # Again, we can implement it via a single call to batched GEMM with stride.

        # Shape (B, N, L_query, C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,jbnc->ibnc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        value.transpose(
                                            (1, 2, 0, 3))).transpose(
                                                (2, 0, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    else:
        raise NotImplementedError(
            'layout="{}" is not supported! '
            'We only support layout = "NKT", "NTK", and "TNK".'.format(layout))
    return context_vec, [scores, attn_weights]
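The following is a reference sketch of the 'NKT' branch above in plain NumPy (no mask, no dropout, no edge scores); all shapes and names are illustrative assumptions. It computes scores = <h_q, h_k> / sqrt(dim_q), a softmax over the memory axis, and the per-head weighted sum of the values, then flattens the heads as in the final npx.reshape.

# Plain-NumPy reference for the scaled 'NKT' path above (no mask/dropout);
# all sizes are illustrative assumptions.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

B, N, L_q, L_m, C_q, C_v = 2, 4, 5, 7, 8, 8
rng = np.random.default_rng(0)
query = rng.standard_normal((B, N, L_q, C_q))    # (B, N, L_query, C_Q)
key = rng.standard_normal((B, N, L_m, C_q))      # (B, N, L_mem, C_Q)
value = rng.standard_normal((B, N, L_m, C_v))    # (B, N, L_mem, C_V)

scores = np.einsum('bnic,bnjc->bnij', query, key) / np.sqrt(C_q)  # (B, N, L_q, L_m)
attn_weights = softmax(scores, axis=-1)
context = np.einsum('bnij,bnjc->binc', attn_weights, value)       # (B, L_q, N, C_V)
context_vec = context.reshape(B, L_q, N * C_v)                    # (B, L_q, N*C_V)
assert context_vec.shape == (B, L_q, N * C_v)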
Example No. 5
def dot_attn_score(query,
                   key,
                   scaled=True,
                   normalized=False,
                   eps=1E-6,
                   layout='NT'):
    """The inner function call to calculate the score used in dot-product attention.

    We support multiple leading batch dimensions.

    scaled is True:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)

    normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>

    both scaled and normalized:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    Parameters
    ----------
    query : symbol or ndarray
        - layout is 'NT'
            (B0, ..., BN, query_length, query_dim)
        - layout is 'TN'
            (query_length, B0, ..., BN, query_dim)
    key : symbol or ndarray
        - layout is 'NT'
            (B0, ..., BN, key_length, key_dim)
        - layout is 'TN'
            (key_length, B0, ..., BN, key_dim)
    scaled : bool
        Whether to divide the query by the square root of the query_dim.
        If True: D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    normalized : bool
        Whether to L2-normalize the query and the key embeddings.
        If True: D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    eps : float
        The epsilon used in the normalization
    layout
        The layout of the layer. Can be 'TN' or 'NT'.

    Returns
    -------
    scores : symbol or ndarray
        (B0, ..., BN, query_length, key_length)
    """
    if normalized:
        query = l2_normalize(query, -1, eps=eps)
        key = l2_normalize(key, -1, eps=eps)
    if scaled:
        query_shape = npx.shape_array(query)
        # TODO(sxjscience) Remove .astype(np.float32).
        #  Wait for https://github.com/apache/incubator-mxnet/issues/18084
        query_units = query_shape[-1].astype(np.float32)
        query = query / np.sqrt(query_units)
    if layout == 'NT':
        scores = npx.batch_dot(query, key, transpose_b=True)
    else:
        raise NotImplementedError(
            'layout={} is not supported.'
            ' Currently, only layout = "NT" is implemented!'.format(layout))
    return scores
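A plain-NumPy sketch of the 'NT' branch above with a single leading batch dimension; the sizes are illustrative assumptions. The query is scaled by 1/sqrt(query_dim) and contracted with the key over the feature axis.

# Plain-NumPy sketch of dot_attn_score's 'NT' branch; sizes are illustrative.
import numpy as np

B, L_q, L_k, D = 2, 5, 7, 16
rng = np.random.default_rng(0)
query = rng.standard_normal((B, L_q, D))
key = rng.standard_normal((B, L_k, D))

scores = np.einsum('bqd,bkd->bqk', query / np.sqrt(D), key)  # (B, L_q, L_k)
assert scores.shape == (B, L_q, L_k)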