Example #1
    def forward(self, x, layer_states):
        """

        Parameters
        ----------
        x
            - layout = 'NT'
                Shape (batch_size, seq_length, C_in)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_in)

        layer_states
            - layout = 'NT'
                Shape (2, batch_size, prev_len, C_in)
            - layout = 'TN'
                Shape (2, prev_len, batch_size, C_in)
        """
        x = self.ln(x)
        if self._layout == 'NT':
            batch_axis, time_axis = 0, 1
        else:
            batch_axis, time_axis = 1, 0
        # layer_states may be None on the first step; only read its shape when
        # it is given. Its time axis is shifted by one because of the leading
        # stacking axis of size 2.
        prev_len = None
        if layer_states is not None:
            prev_len = npx.shape_array(layer_states)[time_axis + 1]

        query, key, value = np.split(self.qkv(x), 3, axis=-1)
        if layer_states is not None:
            prev_key, prev_value = layer_states[0], layer_states[1]
            key = np.concatenate([prev_key, key], axis=time_axis)
            value = np.concatenate([prev_value, value], axis=time_axis)
        new_states = np.stack([key, value], axis=0)

        # generate the causal attention mask (1 --> attend, 0 --> masked)
        query_pos = npx.arange_like(query, axis=time_axis)
        if prev_len is not None:
            query_pos = query_pos + prev_len
        key_pos = npx.arange_like(key, axis=time_axis)
        # (query_len, key_len)
        mask = (npx.reshape(key_pos,
                            (1, -1)) <= npx.reshape(query_pos,
                                                    (-1, 1))).astype(
                                                        self._dtype)
        # broadcast to (batch_size, query_len, key_len)
        mask = npx.broadcast_like(np.expand_dims(mask, axis=0),
                                  query,
                                  lhs_axes=0,
                                  rhs_axes=batch_axis)

        query = npx.reshape(query, (-2, -2, self._num_heads, -1))
        key = npx.reshape(key, (-2, -2, self._num_heads, -1))
        value = npx.reshape(value, (-2, -2, self._num_heads, -1))

        out, [_, attn_weight] = self.attention_cell(query, key, value, mask)
        out = self.out_proj(out)
        out = self.hidden_dropout(out)

        return out, new_states
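The mask construction above generalizes the causal mask to incremental decoding: the
positions of the new query tokens are offset by prev_len, so each new token can attend
to every cached key and to itself. A minimal plain-NumPy sketch of that logic (the
variable names here are illustrative, not from the source):

import numpy as onp

prev_len, query_len = 2, 3                     # cached steps + new steps
key_len = prev_len + query_len
query_pos = onp.arange(query_len) + prev_len   # positions of the new tokens
key_pos = onp.arange(key_len)
# (query_len, key_len): attend to every position at or before the query position
mask = (key_pos[None, :] <= query_pos[:, None]).astype(onp.float32)
print(mask)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]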
Example #2
def test_reshape():
    A = np.ones((INT_OVERFLOW, 2))
    A.attach_grad()
    with mx.autograd.record():
        B = npx.reshape(A, (-5,))  # -5 fuses the first two axes into one
    assert B.shape == (DOUBLE_INT_OVERFLOW, )
    assert B[0] == 1
    B.backward()
    assert A.grad.shape == (INT_OVERFLOW, 2)
    assert A.grad[0][0] == 1
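`npx.reshape` accepts special codes in the target shape, as used throughout these
examples: -1 infers a dimension, -2 copies one input dimension, -4 copies all the
remaining dimensions, and -5 fuses the next two input dimensions into one. That is why
the (INT_OVERFLOW, 2) array above becomes a flat array of length 2 * INT_OVERFLOW. The
same fuse on a small array, sketched in plain NumPy:

import numpy as onp

A = onp.ones((4, 2))
B = A.reshape(4 * 2)   # what npx.reshape(A, (-5,)) does to the first two axes
assert B.shape == (8,)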
Example #3
    def forward(self, step_data, past_states):
        mem_states, mem_valid_length, position, past_key_values = past_states
        step_hidden_states = self.model.input_embedding_layer(step_data)
        # NT: (B, d_model) -> (B, 1, d_model); TN: (B, d_model) -> (1, B, d_model)
        step_hidden_states = np.expand_dims(step_hidden_states,
                                            axis=self.model._time_axis)
        step_hidden_states, present_key_values = self.model.decoder.incremental_decode(
            step_hidden_states, position, past_key_values, mem_states,
            mem_valid_length)
        step_hidden_states = self.output_layer(step_hidden_states)
        # NT: (B, 1, vocab_size) -> (B, vocab_size); TN: (1, B, vocab_size) -> (B, vocab_size)
        step_hidden_states = npx.reshape(step_hidden_states, (-5, -1))
        return step_hidden_states, (mem_states, mem_valid_length, position + 1,
                                    present_key_values)
Example #4
def gen_rel_position(data, past_data=None, dtype=np.int32, layout='NT'):
    """Create a matrix of relative position for RelAttentionScoreCell. 
    
    The relative position is defined as the index difference: `mem_i` - `query_j`. 
    Note, though, that the implementation here makes sense in self-attention's setting, 
    but not in cross-attention's. Hence, both `mem_i` and `query_j` are time indices from 
    `data` (or, in incremental decoding's case, the concatenation of the previous
    steps' `past_data` with the current stepwise `data`).

    Parameters
    ----------
    data
        The data. Under incremental decoding, seq_length = 1. 

        - layout = 'NT'
            Shape (batch_size, seq_length, C)
        - layout = 'TN'
            Shape (seq_length, batch_size, C)
    past_data
        This is only used under incremental decoding. Stacked data from previous steps. 
    dtype
        Data type of the returned relative-position matrix
    layout
        Layout of the data + past_data

    Returns
    -------
    relative_position :
        Shape (query_length, mem_length) where query_length = mem_length = seq_length
    """
    time_axis = 1 if layout == 'NT' else 0
    if past_data is None: 
        position = npx.arange_like(data, axis=time_axis)
    else: 
        # for incremental decoding only, where past data is of the shape: 
        # NT(NTK): (B, L_seq, num_heads, n_kv) -> (B, L_seq, inner_dim)
        # TN(TNK): (L_seq, B, num_heads, n_kv) -> (L_seq, B, inner_dim)
        past_data = npx.reshape(past_data, (-2, -2, -5))
        position = npx.arange_like(
            np.concatenate([past_data, data], axis=time_axis), 
            axis=time_axis
        )
    query_position = np.expand_dims(position, axis=-1)
    mem_position = np.expand_dims(position, axis=0)
    relative_position = mem_position - query_position
    return relative_position.astype(dtype)  # shape (L_seq, L_seq)
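To make the sign convention concrete: entry (i, j) of the returned matrix is
mem_j - query_i, so memory positions after the query are positive and those before it
are negative. A plain-NumPy rendering for seq_length = 4:

import numpy as onp

position = onp.arange(4)                       # seq_length = 4
query_position = position[:, None]             # (4, 1)
mem_position = position[None, :]               # (1, 4)
print(mem_position - query_position)
# [[ 0  1  2  3]
#  [-1  0  1  2]
#  [-2 -1  0  1]
#  [-3 -2 -1  0]]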
Example #5
def add_vectors_by_position(data, increment, positions):
    """Scatter each batch with the given positions.

    data[i, positions[i, j], ...] += increment[i, j, ...]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length, ...)
    increment
        Input tensor of the increment values
        Shape (batch_size, num_disp_position, ...)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length, ...)
    """
    # Here, we use index_add to scatter the increments into data:
    # We need to compute
    #   out[i, positions[i, j], ...] += increment[i, j, ...]
    # Thus, construct indices with shape (2, batch_size * num_disp_position), where
    #     indices[0, i * num_disp_position + j] = i
    #     indices[1, i * num_disp_position + j] = positions[i, j]
    # and reshape increment to (batch_size * num_disp_position, ...).
    # Then, out = npx.index_add(data, indices, increment)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx.reshape((-1, )), positions.reshape((-1, ))])
    out = npx.index_add(data, indices, npx.reshape(increment, (-5, -4)))
    return out
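The same scatter-add semantics can be sketched with plain NumPy's np.add.at, which,
unlike fancy-indexed +=, accumulates contributions at duplicate positions (a toy
illustration, not the library implementation):

import numpy as onp

data = onp.zeros((2, 4))                        # (batch_size=2, seq_length=4)
increment = onp.array([[1., 2.], [3., 4.]])     # (batch_size, num_disp_position=2)
positions = onp.array([[0, 2], [1, 1]])         # note the duplicate position in row 1
batch_idx = onp.repeat(onp.arange(2), 2)        # [0, 0, 1, 1]
onp.add.at(data, (batch_idx, positions.reshape(-1)), increment.reshape(-1))
print(data)
# [[1. 0. 2. 0.]
#  [0. 7. 0. 0.]]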
Example #6
def update_vectors_by_position(data, val, positions):
    """
    Update each batch at the given positions. This is the reverse of
    "select_vectors_by_position" and is similar to "add_vectors_by_position",
    except that it overwrites the values instead of adding to them.

    data[i, positions[i, j]] = val[i, j]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length)
    val
        Input tensor of the new values
        Shape (batch_size, num_disp_position)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length)
    """
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx.reshape((-1, )), positions.reshape((-1, ))])

    out = npx.index_update(data, indices, npx.reshape(val, (-5, -4)))
    return out
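The update variant overwrites instead of accumulating, so in plain NumPy it is
ordinary fancy-index assignment (again a toy illustration):

import numpy as onp

data = onp.zeros((2, 4))
val = onp.array([[9., 8.], [7., 6.]])
positions = onp.array([[0, 2], [1, 3]])
batch_idx = onp.repeat(onp.arange(2), 2)        # [0, 0, 1, 1]
data[batch_idx, positions.reshape(-1)] = val.reshape(-1)
print(data)
# [[9. 0. 8. 0.]
#  [0. 7. 0. 6.]]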
Example #7
    def forward(self, rel_positions, query=None):
        """Forward function

        Parameters
        ----------
        rel_positions
            The relative shifts. Shape (query_length, mem_length).
            Each element represents the shift between the :math:`i`-th element of the query and
            the :math:`j`-th element of the memory.
        query
            The query for computing the relative scores. The shape depends on the layout.
            If we use T5 attention, the query will not be used.

        Returns
        -------
        rel_scores
            The relative attention scores
            Can have shape (batch_size, num_heads, query_length, mem_length)
            or (num_heads, query_length, mem_length)
        """
        if self._method == 'transformer_xl' or self._method == 'shaw':
            assert query is not None, 'Must specify query if method={}'.format(self._method)
            if self._bidirectional:
                if self._max_distance is not None:
                    rel_positions = np.clip(rel_positions,
                                              a_min=-self._max_distance, a_max=self._max_distance)
            else:
                if self._max_distance is not None:
                    rel_positions = np.clip(rel_positions,
                                              a_min=0, a_max=self._max_distance)
            # uniq_rel.shape = (#uniq,), rev_index.shape = (L_q, L_m)
            uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True)

            uniq_rel_pos_embed = self._rel_pos_embed(uniq_rel)
            if self._method == 'transformer_xl':
                uniq_rel_pos_embed = self._rel_proj(self._dropout_layer(uniq_rel_pos_embed))
            # Shape (#uniq, K, C_q)
            uniq_rel_pos_embed = npx.reshape(uniq_rel_pos_embed,
                                               (-2, self._num_heads, self._head_query_units))
            # Calculate the dot-product between query and the relative positional embeddings.
            # After the calculation, rel_score.shape = (L_q, #uniq, N, K)
            if self._layout == 'NKT':
                # query_for_rel: (N, K, L_q, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('bnid,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(query,
                                    np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            elif self._layout == 'NTK':
                # query_for_rel: (N, L_q, K, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('bind,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(np.swapaxes(query, 1, 2),
                                    np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            elif self._layout == 'TNK':
                # query_for_rel: (L_q, N, K, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('ibnd,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(np.transpose(query, (1, 2, 0, 3)),
                                    np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            else:
                raise NotImplementedError
            # We use gather_nd to select the elements
            # TODO(sxjscience) Use advanced indexing once available
            rev_index = npx.reshape_like(rev_index, rel_positions).astype(np.int32)
            query_idx = np.expand_dims(npx.arange_like(rel_positions, axis=0).astype(np.int32),
                                         axis=-1) + np.zeros_like(rev_index)
            rel_score = npx.gather_nd(rel_score, np.stack([query_idx, rev_index]))
            rel_score = np.transpose(rel_score, (2, 3, 0, 1))
        elif self._method == 't5':
            # shape is (K, L_q, L_m)
            rel_score = self._rel_pos_embed(rel_positions).transpose((2, 0, 1))
        else:
            raise NotImplementedError
        return rel_score
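The np.unique(..., return_inverse=True) trick above avoids projecting the positional
embedding for every (query, memory) pair: the embedding (and, for transformer_xl, the
projection) is computed once per unique shift, and the full score matrix is recovered
by gathering with the inverse index. A plain-NumPy sketch of the deduplicate-then-gather
pattern, with a scalar stand-in for the embedding lookup:

import numpy as onp

rel_positions = onp.array([[0, 1, 2],
                           [-1, 0, 1]])
uniq_rel, rev_index = onp.unique(rel_positions, return_inverse=True)
rev_index = rev_index.reshape(rel_positions.shape)
uniq_scores = uniq_rel.astype(onp.float32) * 10.   # stand-in for embed + project
rel_score = uniq_scores[rev_index]                 # gather back to (L_q, L_m)
print(uniq_rel)    # [-1  0  1  2]
print(rel_score)
# [[  0.  10.  20.]
#  [-10.   0.  10.]]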
Example #8
def gen_self_attn_mask(data,
                       valid_length=None,
                       dtype: type = np.float32,
                       attn_type: str = 'full',
                       layout: str = 'NT'):
    """Generate the mask used for the encoder, i.e, self-attention.

    In our implementation, 1 --> not masked, 0 --> masked

    Let's consider the data with two samples:

    .. code-block:: none

        data =
            [['I',   'can', 'now',   'use', 'numpy', 'in',  'Gluon@@', 'NLP'  ],
             ['May', 'the', 'force', 'be',  'with',  'you', '<PAD>',   '<PAD>']]
        valid_length =
            [8, 6]

    - attn_type = 'causal'
        Each token will attend to itself and the tokens before it.
        It will not attend to tokens in the future.

        For our example, the mask of the first sample is

        .. code-block:: none

                       ['I', 'can', 'now', 'use', 'numpy', 'in', 'Gluon@@', 'NLP']
            'I':         1,    0,     0,     0,      0,     0,      0,      0
            'can':       1,    1,     0,     0,      0,     0,      0,      0
            'now':       1,    1,     1,     0,      0,     0,      0,      0
            'use':       1,    1,     1,     1,      0,     0,      0,      0
            'numpy':     1,    1,     1,     1,      1,     0,      0,      0
            'in':        1,    1,     1,     1,      1,     1,      0,      0
            'Gluon@@':   1,    1,     1,     1,      1,     1,      1,      0
            'NLP':       1,    1,     1,     1,      1,     1,      1,      1

        The mask of the second sample is

        .. code-block:: none

                       ['May', 'the', 'force', 'be', 'with', 'you', '<PAD>', '<PAD>']
            'May':        1,    0,     0,     0,      0,     0,      0,      0
            'the':        1,    1,     0,     0,      0,     0,      0,      0
            'force':      1,    1,     1,     0,      0,     0,      0,      0
            'be':         1,    1,     1,     1,      0,     0,      0,      0
            'with':       1,    1,     1,     1,      1,     0,      0,      0
            'you':        1,    1,     1,     1,      1,     1,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0


    - attn_type = 'full'
        Each token will attend to the tokens both before and after it.

        For our example, the mask of the first sample is

        .. code-block:: none

                       ['I', 'can', 'now', 'use', 'numpy', 'in', 'Gluon@@', 'NLP']
            'I':         1,    1,     1,     1,      1,     1,      1,      1
            'can':       1,    1,     1,     1,      1,     1,      1,      1
            'now':       1,    1,     1,     1,      1,     1,      1,      1
            'use':       1,    1,     1,     1,      1,     1,      1,      1
            'numpy':     1,    1,     1,     1,      1,     1,      1,      1
            'in':        1,    1,     1,     1,      1,     1,      1,      1
            'Gluon@@':   1,    1,     1,     1,      1,     1,      1,      1
            'NLP':       1,    1,     1,     1,      1,     1,      1,      1

        The mask of the second sample is

        .. code-block:: none

                       ['May', 'the', 'force', 'be', 'with', 'you', '<PAD>', '<PAD>']
            'May':        1,    1,     1,     1,      1,     1,      0,      0
            'the':        1,    1,     1,     1,      1,     1,      0,      0
            'force':      1,    1,     1,     1,      1,     1,      0,      0
            'be':         1,    1,     1,     1,      1,     1,      0,      0
            'with':       1,    1,     1,     1,      1,     1,      0,      0
            'you':        1,    1,     1,     1,      1,     1,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0

    Parameters
    ----------
    data
        The data.

        - layout = 'NT'
            Shape (batch_size, seq_length, C)
        - layout = 'TN'
            Shape (seq_length, batch_size, C)

    valid_length
        Shape (batch_size,)
    dtype
        Data type of the mask
    attn_type
        Can be 'full' or 'causal'
    layout
        The layout of the data

    Returns
    -------
    mask
        Shape (batch_size, seq_length, seq_length)
    """
    if layout == 'NT':
        batch_axis, time_axis = 0, 1
    elif layout == 'TN':
        batch_axis, time_axis = 1, 0
    else:
        raise NotImplementedError('Unsupported layout={}'.format(layout))
    if attn_type == 'full':
        if valid_length is not None:
            valid_length = valid_length.astype(dtype)
            steps = npx.arange_like(data, axis=time_axis)  # (seq_length,)
            mask1 = (npx.reshape(steps, (1, 1, -1))
                     < npx.reshape(valid_length, (-2, 1, 1)))
            mask2 = (npx.reshape(steps, (1, -1, 1))
                     < npx.reshape(valid_length, (-2, 1, 1)))
            mask = mask1 * mask2
        else:
            # TODO(sxjscience) optimize
            seq_len_ones = np.ones_like(npx.arange_like(data, axis=time_axis))  # (seq_length,)
            batch_ones = np.ones_like(npx.arange_like(data, axis=batch_axis))   # (batch_size,)
            mask = batch_ones.reshape((-1, 1, 1)) * seq_len_ones.reshape((1, -1, 1))\
                   * seq_len_ones.reshape((1, 1, -1))
    elif attn_type == 'causal':
        steps = npx.arange_like(data, axis=time_axis)
        # mask: (seq_length, seq_length)
        # batch_mask: (batch_size, seq_length)
        mask = (np.expand_dims(steps, axis=0) <= np.expand_dims(steps, axis=1)).astype(dtype)
        if valid_length is not None:
            valid_length = valid_length.astype(dtype)
            batch_mask = (np.expand_dims(steps, axis=0) < np.expand_dims(valid_length, axis=-1)).astype(dtype)
            mask = mask * np.expand_dims(batch_mask, axis=-1)
        else:
            batch_ones = np.ones_like(npx.arange_like(data, axis=batch_axis),
                                        dtype=dtype)  # (batch_size,)
            mask = mask * batch_ones.reshape((-1, 1, 1))
    else:
        raise NotImplementedError
    return mask.astype(np.bool)
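A plain-NumPy sketch of the 'causal' branch for a toy batch (seq_length = 5,
valid_length = [5, 3]); it builds the causal triangle and zeroes out the rows of
padded query positions, exactly as the function does:

import numpy as onp

steps = onp.arange(5)
valid_length = onp.array([5, 3])
causal = (steps[None, :] <= steps[:, None]).astype(onp.float32)            # (L, L)
batch_mask = (steps[None, :] < valid_length[:, None]).astype(onp.float32)  # (B, L)
mask = causal[None] * batch_mask[:, :, None]
print(mask[1])
# [[1. 0. 0. 0. 0.]
#  [1. 1. 0. 0. 0.]
#  [1. 1. 1. 0. 0.]
#  [0. 0. 0. 0. 0.]
#  [0. 0. 0. 0. 0.]]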
Example #9
def multi_head_dot_attn(query, key, value,
                        mask=None,
                        edge_scores=None,
                        dropout: float = 0.0,
                        scaled: bool = True, normalized: bool = False,
                        eps: float = 1E-6, query_head_units: Optional[int] = None,
                        layout: str = 'NKT',
                        use_einsum: bool = False):
    """Multihead dot product attention between the query, key, value.

    scaled is False, normalized is False:
        D(h_q, h_k) = <h_q, h_k>
    scaled is True, normalized is False:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    scaled is False, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    scaled is True, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    If edge_scores is provided, we will calculate the attention as
        scores = D(h_q, h_k) + EdgeScore_{q, k}

    Parameters
    ----------
    query
        Query. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, query_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, query_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads, key_dim)

    key
        Key. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, key_dim)

    value
        Value. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, value_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, value_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, value_dim)

    mask
        Mask between query and memory. Shape (batch_size, query_length, mem_length)
    edge_scores
        The edge attention score. Shape can be any shape that is broadcastable to
        (batch_size, num_heads, query_length, mem_length)
    dropout
        Dropout rate
    scaled
        Whether to divide the attention weights by the sqrt of the query dimension.
        This is first proposed in "[NIPS2017] Attention is all you need.":

        .. code-block:: none

            score = <h_q, h_k> / sqrt(dim_q)

    normalized
        If turned on, the cosine distance is used, i.e.:

        .. code-block:: none

            score = <h_q / ||h_q||, h_k / ||h_k||>

    eps
        The epsilon value used in L2 normalization
    query_head_units
        The units of each query head. This must be specified when ``scaled`` is True.
    layout
        This stands for the layout of the attention cell. The shape of the input/output will depend
        on the layout. Currently, we support 'NKT', 'NTK' and 'TNK' in which
        'N' means the batch_size, 'K' means the head, and 'T' means the length dimension.
    use_einsum
        Whether to use einsum for the computation

    Returns
    -------
    context_vec
        - layout is 'NKT' or 'NTK'
            Shape (batch_size, query_length, num_heads * value_units)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads * value_units)

    additional_info
        scores:
            Shape (batch_size, num_head, query_length, mem_length)
        attn_weight:
            Shape (batch_size, num_head, query_length, mem_length)
    """
    # TODO(sxjscience) Profile layout
    if normalized:
        query = l2_normalize(query, axis=-1, eps=eps)
        key = l2_normalize(key, axis=-1, eps=eps)
    if scaled:
        if query_head_units is None:
            raise NotImplementedError('You will need to specify query_head_units!')
        else:
            scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout == 'NKT':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bnjc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'NTK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum('binc,bjnc->bnij', query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
                                   transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bjnc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                          np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'TNK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        #   This layout structure can be implemented very efficiently because B, N are consecutive
        #   to each other. To have a clear picture of what's happening, we may consider the
        #   (i, j)th element of the output
        #       out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T, which is just one GEMM call
        #   We can thus implement the whole kernel via a single call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum('ibnc,jbnc->bnij', query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                     key.transpose((1, 2, 3, 0)))
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
        # Again, we can implement it via a single call to batched GEMM with stride.

        # Shape (B, N, L_query, C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,jbnc->ibnc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                          value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    else:
        raise NotImplementedError('layout="{}" is not supported! '
                                  'We only support layout = "NKT", "NTK", and "TNK".'
                                  .format(layout))
    return context_vec, [scores, attn_weights]
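Stripped of masking, dropout, and layout shuffling, every branch above computes the
same scaled dot-product attention. A self-contained plain-NumPy sketch of the 'NKT'
path (mask and dropout omitted; names are illustrative):

import numpy as onp

B, N, L, C = 2, 4, 5, 8                        # batch, heads, length, head units
rng = onp.random.default_rng(0)
q, k, v = (rng.standard_normal((B, N, L, C)) for _ in range(3))
scores = onp.einsum('bnic,bnjc->bnij', q, k) / onp.sqrt(C)   # (B, N, L_q, L_m)
weights = onp.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)                    # softmax over L_m
context = onp.einsum('bnij,bnjc->binc', weights, v).reshape(B, L, N * C)
assert context.shape == (B, L, N * C)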
Example #10
def gen_mem_attn_mask(mem, mem_valid_length, data, data_valid_length=None,
                      dtype=np.float32, layout: str = 'NT'):
    """Generate the mask used for the decoder. All query slots are attended to the memory slots.

    In our implementation, 1 --> not masked, 0 --> masked

    Let's consider the data + mem with a batch of two samples:

    .. code-block:: none

        mem = [['I',   'can', 'now',   'use'],
               ['May', 'the', 'force', '<PAD>']]
        mem_valid_length =
            [4, 3]
        data =
            [['numpy', 'in',    'Gluon@@', 'NLP'  ],
             ['be',    'with',  'you',     '<PAD>']]
        data_valid_length =
            [4, 3]

    For our example, the mask of the first sample is

    .. code-block:: none

                   ['I', 'can', 'now', 'use']
        'numpy':     1,    1,     1,     1
        'in':        1,    1,     1,     1
        'Gluon@@':   1,    1,     1,     1
        'NLP':       1,    1,     1,     1

    The mask of the second sample is

    .. code-block:: none

                   ['May', 'the', 'force', '<PAD>']
        'be':         1,     1,      1,       0
        'with':       1,     1,      1,       0
        'you':        1,     1,      1,       0
        '<PAD>':      0,     0,      0,       0


    Parameters
    ----------
    mem
       - layout = 'NT'
            Shape (batch_size, mem_length, C_mem)
       - layout = 'TN'
            Shape (mem_length, batch_size, C_mem)

    mem_valid_length :
        Shape (batch_size,)
    data
        - layout = 'NT'
            Shape (batch_size, query_length, C_data)
        - layout = 'TN'
            Shape (query_length, batch_size, C_data)

    data_valid_length :
        Shape (batch_size,)
    dtype
        Data type of the mask
    layout
        Layout of the data + mem tensor

    Returns
    -------
    mask :
        Shape (batch_size, query_length, mem_length)
    """
    if layout == 'NT':
        batch_axis, time_axis = 0, 1
    elif layout == 'TN':
        batch_axis, time_axis = 1, 0
    else:
        raise NotImplementedError('Unsupported layout={}'.format(layout))
    mem_valid_length = mem_valid_length.astype(dtype)
    mem_steps = npx.arange_like(mem, axis=time_axis)  # (mem_length,)
    data_steps = npx.arange_like(data, axis=time_axis)  # (query_length,)
    mem_mask = (npx.reshape(mem_steps, (1, 1, -1))
                < npx.reshape(mem_valid_length, (-2, 1, 1))).astype(dtype)  # (B, 1, mem_length)
    if data_valid_length is not None:
        data_valid_length = data_valid_length.astype(dtype)
        data_mask = (npx.reshape(data_steps, (1, -1, 1))
                     < npx.reshape(data_valid_length, (-2, 1, 1))).astype(dtype)  # (B, query_length, 1)
        mask = mem_mask * data_mask
    else:
        query_length_ones = np.ones_like(data_steps)
        mask = query_length_ones.reshape((1, -1, 1)) * mem_mask
    return mask.astype(np.bool)
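The outer product of the two valid-length masks is all this function computes; a
plain-NumPy sketch that reproduces the second sample above:

import numpy as onp

mem_valid_length = onp.array([4, 3])
data_valid_length = onp.array([4, 3])
mem_steps = onp.arange(4)[None, None, :]        # (1, 1, mem_length)
data_steps = onp.arange(4)[None, :, None]       # (1, query_length, 1)
mem_mask = mem_steps < mem_valid_length[:, None, None]     # (B, 1, mem_length)
data_mask = data_steps < data_valid_length[:, None, None]  # (B, query_length, 1)
mask = (mem_mask * data_mask).astype(onp.int32)
print(mask[1])
# [[1 1 1 0]
#  [1 1 1 0]
#  [1 1 1 0]
#  [0 0 0 0]]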
Example #11
    def transpose_for_scores(self, x):
        # NT -> NTK: (B, L_seq, inner_dim) -> (B, L_seq, num_heads, n_kv)
        # TN -> TNK: (L_seq, B, inner_dim) -> (L_seq, B, num_heads, n_kv)
        return npx.reshape(x, (-2, -2, self._num_heads, -1))
    def forward(self, data, mem, rel_positions, mask, query_r_bias,
                query_k_bias):
        """

        Parameters
        ----------
        data
            The input data.
            layout = 'NT':
                Shape (batch_size, query_length, units)
            layout = 'TN':
                Shape (query_length, batch_size, units)
        mem
            The memory.
            layout = 'NT':
                Shape (batch_size, mem_length, units)
            layout = 'TN':
                Shape (mem_length, batch_size, units)
        rel_positions
            The relative positions between data and [mem, data]
            Shape (query_length, mem_length + query_length).
            A positive value means that query is after the memory, i.e.,
            query_location - mem_location.
        mask
            Mask between the query and the memory + query.
            1 --> will be used, 0 --> won't be used
            Shape (batch_size, query_length, mem_length + query_length)
        query_r_bias
            The query bias for calculating the relative scores
            Shape (num_heads, query_head_units)
        query_k_bias
            The key bias for calculating the relative scores.
            Shape (num_heads, query_head_units)

        Returns
        -------
        out
            - layout = 'NT'
                Shape (batch_size, query_length, units)
            - layout = 'TN'
                Shape (query_length, batch_size, units)
        """
        if self._layout == 'NT':
            context = np.concatenate([mem, data], axis=1)
        elif self._layout == 'TN':
            context = np.concatenate([mem, data], axis=0)
        else:
            raise NotImplementedError
        if self._pre_norm:
            query = self.attn_query(self.layer_norm(data))
            key_value = self.attn_kv(self.layer_norm(context))
            key, value = np.split(key_value, 2, axis=-1)
        else:
            query = self.attn_query(data)
            key_value = self.attn_kv(context)
            key, value = np.split(key_value, 2, axis=-1)
        query = npx.reshape(query, (-2, -2, self._num_heads, -1))
        key = npx.reshape(key, (-2, -2, self._num_heads, -1))
        value = npx.reshape(value, (-2, -2, self._num_heads, -1))
        # Compute attention
        rel_score = self.rel_pos_score_cell(rel_positions,
                                            query + query_r_bias)
        out, _ = self.attn_cell(query + query_k_bias, key, value, mask,
                                rel_score)
        out = self.dropout_layer(out)
        if self._pre_norm:
            out = data + out
        else:
            out = self.layer_norm(data + out)
        out = self.ffn(out)
        return out
    def forward(self, data, indices):
        mask = indices < 3
        data = npx.reshape(data, (-1, -2), reverse=True)
        mask = np.reshape(mask, (-1, ))
        # boolean-mask row selection (dispatches to the internal boolean_mask op)
        sel = data[mask]
        return sel
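Boolean-mask selection keeps the rows whose mask entry is True; the same semantics in
plain NumPy:

import numpy as onp

data = onp.arange(12).reshape(4, 3)
indices = onp.array([0, 5, 2, 7])
sel = data[indices < 3]    # keep rows 0 and 2
print(sel)
# [[0 1 2]
#  [6 7 8]]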
Example #14
    def forward(self, data, attn_mask):
        """

        Parameters
        ----------
        data
            - layout = 'NT'
                Shape (batch_size, seq_length, C_in)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_in)
        attn_mask
            The attention mask
            Shape (batch_size, seq_length, seq_length)

        Returns
        -------
        out
            - layout = 'NT'
                Shape (batch_size, seq_length, C_out)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_out)
        attn_weight
            Shape (batch_size, seq_length, seq_length)
        """
        if self._use_bottleneck:
            bn_proj = self.in_bottleneck_proj(data)
            bn_proj = self.in_bottleneck_ln(bn_proj)
            input = bn_proj
            if self._bottleneck_strategy == 'qk_sharing':
                # for Mobile Bert
                qk_shared = self.shared_qk(data)
                qk_shared = self.shared_qk_ln(qk_shared)
                query = qk_shared
                key = qk_shared
                value = data
            elif self._bottleneck_strategy == 'from_bottleneck':
                # for Mobile Bert Tiny
                query = bn_proj
                key = bn_proj
                value = bn_proj
            elif self._bottleneck_strategy == 'from_input':
                query = data
                key = data
                value = data
            else:
                raise NotImplementedError
        else:
            input = data
            query = data
            key = data
            value = data

        query = npx.reshape(self.attn_query(query),
                            (-2, -2, self._num_heads, -1))
        key = npx.reshape(self.attn_key(key), (-2, -2, self._num_heads, -1))
        value = npx.reshape(self.attn_value(value),
                            (-2, -2, self._num_heads, -1))
        out, [_, attn_weight] = self.attention_cell(query, key, value,
                                                    attn_mask)
        out = self.attention_proj(out)
        if not self._use_bottleneck:
            out = self.dropout_layer(out)
        out = out + input
        out = self.layer_norm(out)
        for ffn_idx in range(self._num_stacked_ffn):
            ffn = self.stacked_ffn[ffn_idx]
            out = ffn(out)

        if self._use_bottleneck:
            out = self.out_bottleneck_proj(out)
            out = self.dropout_layer(out)
            out = out + data
            out = self.out_bottleneck_ln(out)
        return out, attn_weight
Example #15
    def transpose_for_scores(x):
        # (..., inner_dim) -> (..., num_heads, head_units); `self` is captured
        # from the enclosing method in the original source
        return npx.reshape(x, (-2, -2, self._num_heads, -1))