Example #1
    def forward(self, x, layer_states):
        """Run causal self-attention over ``x``, optionally continuing from cached states.

        Parameters
        ----------
        x
            - layout = 'NT'
                Shape (batch_size, seq_length, C_in)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_in)
        layer_states
            The stacked key/value of the previous steps, or None when there is no cache.

            - layout = 'NT'
                Shape (2, batch_size, prev_len, C_in)
            - layout = 'TN'
                Shape (2, prev_len, batch_size, C_in)

        Returns
        -------
        out
            The attention output after ``self.out_proj`` and ``self.hidden_dropout``.
        new_states
            ``np.stack([key, value], axis=0)``, where key/value already include the
            cached steps, i.e. the ``layer_states`` to feed to the next call.
        """
        x = self.ln(x)
        if self._layout == 'NT':
            batch_axis, time_axis = 0, 1
        else:
            batch_axis, time_axis = 1, 0
        # layer_states is optional (it is compared against None below), so only
        # read its shape when it exists. Inside layer_states the leading stack
        # axis shifts the time axis by one: (2, B, prev_len, C) / (2, prev_len, B, C).
        if layer_states is None:
            prev_len = None
        else:
            prev_len = npx.shape_array(layer_states)[time_axis + 1]

        query, key, value = np.split(self.qkv(x), 3, axis=-1)
        if layer_states is not None:
            prev_key, prev_value = layer_states[0], layer_states[1]
            key = np.concatenate([prev_key, key], axis=time_axis)
            value = np.concatenate([prev_value, value], axis=time_axis)
        new_states = np.stack([key, value], axis=0)

        # Causal mask: a query at absolute position p may attend to keys at positions <= p.
        # With a cache, the current queries start at absolute position prev_len.
        query_pos = npx.arange_like(query, axis=time_axis)
        if prev_len is not None:
            query_pos = query_pos + prev_len
        key_pos = npx.arange_like(key, axis=time_axis)
        # (query_len, key_len)
        mask = (npx.reshape(key_pos, (1, -1)) <=
                npx.reshape(query_pos, (-1, 1))).astype(x.dtype)
        # broadcast to (batch_size, query_len, key_len): copy the batch size of
        # `query` (found at batch_axis) into the new leading axis of the mask.
        mask = npx.broadcast_like(np.expand_dims(mask, axis=0), query,
                                  lhs_axes=0, rhs_axes=batch_axis)

        # Split the hidden axis into heads: (d0, d1, num_heads, -1)
        query = npx.reshape(query, (-2, -2, self._num_heads, -1))
        key = npx.reshape(key, (-2, -2, self._num_heads, -1))
        value = npx.reshape(value, (-2, -2, self._num_heads, -1))

        out, _ = self.attention_cell(query, key, value, mask)
        out = self.out_proj(out)
        out = self.hidden_dropout(out)

        return out, new_states
def test_reshape():
    """npx.reshape with -5 merges the two axes, and its gradient flows back to A."""
    A = np.ones((INT_OVERFLOW, 2))
    # Autograd only records gradients for arrays with an attached gradient
    # buffer; without this, A.grad is None and the grad assertions below fail.
    A.attach_grad()
    with mx.autograd.record():
        # -5 merges the first two axes: (INT_OVERFLOW, 2) -> (DOUBLE_INT_OVERFLOW,)
        B = npx.reshape(A, (-5,))
    assert B.shape == (DOUBLE_INT_OVERFLOW, )
    assert B[0] == 1
    # backward() with the default head gradient of ones; reshape passes it
    # through unchanged, so every element of A.grad should be 1.
    B.backward()
    assert A.grad.shape == (INT_OVERFLOW, 2)
    assert A.grad[0][0] == 1
Example #3
 def forward(self, step_data, past_states):
     """Run one incremental decoding step.

     Parameters
     ----------
     step_data
         The input of the current step; it is passed through
         ``self.model.input_embedding_layer``, whose output the comments below
         describe as (B, d_model).
     past_states
         A 4-tuple ``(mem_states, mem_valid_length, position, past_key_values)``.

     Returns
     -------
     step_hidden_states
         Output of ``self.output_layer`` reshaped to (B, vocab_size).
     new_states
         A tuple beginning with ``(mem_states, mem_valid_length, position + 1, ...)``.
     """
     mem_states, mem_valid_length, position, past_key_values = past_states
     step_hidden_states = self.model.input_embedding_layer(step_data)
     # NT: (B, d_model) -> (B, 1, d_model); TN: (B, d_model) -> (1, B, d_model)
     # NOTE(review): the next call is truncated in this listing -- its axis argument
     # and closing parenthesis are missing; per the comment above it should insert
     # the layout's time axis. TODO confirm against the model's layout attribute.
     step_hidden_states = np.expand_dims(step_hidden_states,
     # NOTE(review): the argument list below is truncated; presumably mem_valid_length
     # (unpacked above and otherwise unused) follows mem_states -- confirm.
     step_hidden_states, present_key_values = self.model.decoder.incremental_decode(
         step_hidden_states, position, past_key_values, mem_states,
     step_hidden_states = self.output_layer(step_hidden_states)
     # NT: (B, 1, vocab_size) -> (B, vocab_size); TN: (1, B, vocab_size) -> (B, vocab_size)
     step_hidden_states = npx.reshape(step_hidden_states, (-5, -1))
     # NOTE(review): the returned state tuple is truncated; presumably its last element
     # is present_key_values (computed above and otherwise unused) -- confirm.
     return step_hidden_states, (mem_states, mem_valid_length, position + 1,
def gen_rel_position(data, past_data=None, dtype=np.int32, layout='NT'):
    """Create a matrix of relative position for RelAttentionScoreCell.

    The relative position is defined as the index difference: `mem_i` - `query_j`.
    Note, though, that the implementation here makes sense in self-attention's setting,
    but not in cross-attention's. Hence, both `mem_i` and `query_j` are time indices from
    `data` (or, in incremental decoding's case, the concatenated sequence from the current
    stepwise `data` and the previous steps `past_data`).

    Parameters
    ----------
    data
        The data. Under incremental decoding, seq_length = 1.

        - layout = 'NT'
            Shape (batch_size, seq_length, C)
        - layout = 'TN'
            Shape (seq_length, batch_size, C)
    past_data
        This is only used under incremental decoding. Stacked data from previous steps.
        Its last two axes (num_heads, n_kv) are merged before it is concatenated with
        `data` along the time axis.
    dtype
        Data type of the returned relative positions.
    layout
        Layout of the data + past_data. 'NT' uses axis 1 as the time axis; any other
        value uses axis 0.

    Returns
    -------
    relative_position
        Shape (L, L), where L = seq_length, or past_length + seq_length when
        `past_data` is given. relative_position[q, m] = m - q.
    """
    time_axis = 1 if layout == 'NT' else 0
    if past_data is None:
        sequence = data
    else:
        # for incremental decoding only, where past data is of the shape:
        # NT(NTK): (B, L_seq, num_heads, n_kv) -> (B, L_seq, inner_dim)
        # TN(TNK): (L_seq, B, num_heads, n_kv) -> (L_seq, B, inner_dim)
        past_data = npx.reshape(past_data, (-2, -2, -5))
        sequence = np.concatenate([past_data, data], axis=time_axis)
    position = npx.arange_like(sequence, axis=time_axis)
    # (L, 1) and (1, L) broadcast to (L, L)
    query_position = np.expand_dims(position, axis=-1)
    mem_position = np.expand_dims(position, axis=0)
    relative_position = mem_position - query_position
    # Honor the requested dtype (previously hard-coded to np.int32, the default).
    return relative_position.astype(dtype)  # shape (L, L)
Example #5
def add_vectors_by_position(data, increment, positions):
    """Scatter-add each batch's increments into `data` at the given positions.

    data[i, positions[i, j], ...] += increment[i, j, ...]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length, ...)
    increment
        The values to add.
        Shape (batch_size, num_disp_position, ...)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must be valid
        indices into the sequence, i.e. smaller than seq_length.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length, ...)
    """
    # Here, we use index_add to disperse the increments into data. We need
    #   out[i, positions[i, j], ...] = data[i, positions[i, j], ...] + increment[i, j, ...]
    # Thus, construct indices with shape (2, batch_size * num_disp_position), where
    #     indices[0, i * num_disp_position + j] = i
    #     indices[1, i * num_disp_position + j] = positions[i, j]
    # and flatten increment to shape (batch_size * num_disp_position, ...).
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    # Broadcast to (batch_size, num_disp_position) so it lines up with positions.
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx.reshape((-1, )), positions.reshape((-1, ))])
    # (-5, -4): merge the first two axes and keep the remaining ones.
    out = npx.index_add(data, indices, npx.reshape(increment, (-5, -4)))
    return out
Example #6
def update_vectors_by_position(data, val, positions):
    """Update each batch with the given positions.

    Considered as a reversed process of "select_vectors_by_position", this is an operator
    similar to "add_vectors_by_position" that updates the results instead of adding.

    data[i, positions[i, j], :] = val[i, j, :]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length, ...)
    val
        The values written into `data`.
        Shape (batch_size, num_disp_position, ...)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must be valid
        indices into the sequence, i.e. smaller than seq_length.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length, ...)
    """
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    # Broadcast to (batch_size, num_disp_position) so it lines up with positions.
    batch_idx = batch_idx + np.zeros_like(positions)
    # indices[:, i * num_disp_position + j] = (i, positions[i, j])
    indices = np.stack([batch_idx.reshape((-1, )), positions.reshape((-1, ))])
    # (-5, -4): merge the first two axes and keep the remaining ones, giving
    # (batch_size * num_disp_position, ...) to match the columns of indices.
    out = npx.index_update(data, indices, npx.reshape(val, (-5, -4)))
    return out
Example #7
    def forward(self, rel_positions, query=None):
        """Forward function

        Parameters
        ----------
        rel_positions
            The relative shifts. Shape (query_length, mem_length).
            Each element represents the shift between the :math:`i-th` element of query and
            the :math:`j-th` element of memory.
        query
            The query for computing the relative scores. The shape depends on the layout:
            (N, K, L_q, C_q) for 'NKT', (N, L_q, K, C_q) for 'NTK' and
            (L_q, N, K, C_q) for 'TNK', where N is the batch size and K the number of heads.
            If we use T5 attention, the query will not be used.

        Returns
        -------
        rel_score
            The relative attention scores.
            Shape (batch_size, num_heads, query_length, mem_length) for the
            'transformer_xl' and 'shaw' methods, and
            (num_heads, query_length, mem_length) for 't5'.

        Raises
        ------
        NotImplementedError
            If the method or the layout is not supported.
        """
        if self._method in ('transformer_xl', 'shaw'):
            assert query is not None, 'Must specify query if method={}'.format(self._method)
            if self._max_distance is not None:
                # Bidirectional: clip to [-max_distance, max_distance];
                # unidirectional: clip to [0, max_distance].
                a_min = -self._max_distance if self._bidirectional else 0
                rel_positions = np.clip(rel_positions,
                                        a_min=a_min, a_max=self._max_distance)
            # uniq_rel.shape = (#uniq,), rev_index.shape = (L_q, L_m)
            # Only the distinct shifts are embedded; rev_index maps them back below.
            uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True)

            uniq_rel_pos_embed = self._rel_pos_embed(uniq_rel)
            if self._method == 'transformer_xl':
                uniq_rel_pos_embed = self._rel_proj(self._dropout_layer(uniq_rel_pos_embed))
            # Shape (#uniq, K, C_q)
            uniq_rel_pos_embed = npx.reshape(uniq_rel_pos_embed,
                                             (-2, self._num_heads, self._head_query_units))
            # Calculate the dot-product between query and the relative positional embeddings.
            # After the calculation, rel_score.shape = (L_q, #uniq, N, K)
            if self._layout == 'NKT':
                # query_for_rel: (N, K, L_q, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('bnid,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    # (N, K, L_q, C_q) x (K, C_q, #uniq) -> (N, K, L_q, #uniq)
                    rel_score = np.transpose(
                        np.matmul(query,
                                  np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            elif self._layout == 'NTK':
                # query_for_rel: (N, L_q, K, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('bind,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(np.swapaxes(query, 1, 2),
                                  np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            elif self._layout == 'TNK':
                # query_for_rel: (L_q, N, K, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('ibnd,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(np.transpose(query, (1, 2, 0, 3)),
                                  np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            else:
                raise NotImplementedError
            # We use gather_nd to select the elements
            # TODO(sxjscience) Use advanced indexing once available
            rev_index = npx.reshape_like(rev_index, rel_positions).astype(np.int32)
            query_idx = np.expand_dims(npx.arange_like(rel_positions, axis=0).astype(np.int32),
                                       axis=-1) + np.zeros_like(rev_index)
            # rel_score[i, rev_index[i, j]] -> (L_q, L_m, N, K)
            rel_score = npx.gather_nd(rel_score, np.stack([query_idx, rev_index]))
            # (L_q, L_m, N, K) -> (N, K, L_q, L_m)
            rel_score = np.transpose(rel_score, (2, 3, 0, 1))
        elif self._method == 't5':
            # shape is (K, L_q, L_m)
            rel_score = self._rel_pos_embed(rel_positions).transpose((2, 0, 1))
        else:
            raise NotImplementedError
        return rel_score
Example #8
def gen_self_attn_mask(data,
                       valid_length=None,
                       dtype: type = np.float32,
                       attn_type: str = 'full',
                       layout: str = 'NT'):
    """Generate the mask used for the encoder, i.e, self-attention.

    In our implementation, 1 --> not masked, 0 --> masked

    Let's consider the data with two samples:

    .. code-block:: none

        data =
            [['I',   'can', 'now',   'use', 'numpy', 'in',  'Gluon@@', 'NLP'  ],
             ['May', 'the', 'force', 'be',  'with',  'you', '<PAD>',   '<PAD>']]
        valid_length =
            [8, 6]

    - attn_type = 'causal'
        Each token will attend to itself + the tokens before.
        It will not attend to tokens in the future.

        For our example, the mask of the first sample is

        .. code-block:: none

                       ['I', 'can', 'now', 'use', 'numpy', 'in', 'Gluon@@', 'NLP']
            'I':         1,    0,     0,     0,      0,     0,      0,      0
            'can':       1,    1,     0,     0,      0,     0,      0,      0
            'now':       1,    1,     1,     0,      0,     0,      0,      0
            'use':       1,    1,     1,     1,      0,     0,      0,      0
            'numpy':     1,    1,     1,     1,      1,     0,      0,      0
            'in':        1,    1,     1,     1,      1,     1,      0,      0
            'Gluon@@':   1,    1,     1,     1,      1,     1,      1,      0
            'NLP':       1,    1,     1,     1,      1,     1,      1,      1

        The mask of the second sample is

        .. code-block:: none

                       ['May', 'the', 'force', 'be', 'with', 'you', '<PAD>', '<PAD>']
            'May':        1,    0,     0,     0,      0,     0,      0,      0
            'the':        1,    1,     0,     0,      0,     0,      0,      0
            'force':      1,    1,     1,     0,      0,     0,      0,      0
            'be':         1,    1,     1,     1,      0,     0,      0,      0
            'with':       1,    1,     1,     1,      1,     0,      0,      0
            'you':        1,    1,     1,     1,      1,     1,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0

    - attn_type = 'full'
        Each token will attend to both the tokens before and in the future

        For our example, the mask of the first sample is

        .. code-block:: none

                       ['I', 'can', 'now', 'use', 'numpy', 'in', 'Gluon@@', 'NLP']
            'I':         1,    1,     1,     1,      1,     1,      1,      1
            'can':       1,    1,     1,     1,      1,     1,      1,      1
            'now':       1,    1,     1,     1,      1,     1,      1,      1
            'use':       1,    1,     1,     1,      1,     1,      1,      1
            'numpy':     1,    1,     1,     1,      1,     1,      1,      1
            'in':        1,    1,     1,     1,      1,     1,      1,      1
            'Gluon@@':   1,    1,     1,     1,      1,     1,      1,      1
            'NLP':       1,    1,     1,     1,      1,     1,      1,      1

        The mask of the second sample is

        .. code-block:: none

                       ['May', 'the', 'force', 'be', 'with', 'you', '<PAD>', '<PAD>']
            'May':        1,    1,     1,     1,      1,     1,      0,      0
            'the':        1,    1,     1,     1,      1,     1,      0,      0
            'force':      1,    1,     1,     1,      1,     1,      0,      0
            'be':         1,    1,     1,     1,      1,     1,      0,      0
            'with':       1,    1,     1,     1,      1,     1,      0,      0
            'you':        1,    1,     1,     1,      1,     1,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0

    Parameters
    ----------
    data
        The data.

        - layout = 'NT'
            Shape (batch_size, seq_length, C)
        - layout = 'TN'
            Shape (seq_length, batch_size, C)
    valid_length
        Shape (batch_size,). If None, every position is treated as valid.
    dtype
        Data type used for the intermediate computations. The returned mask is
        always cast to bool.
    attn_type
        Can be 'full' or 'causal'
    layout
        The layout of the data

    Returns
    -------
    mask
        Boolean mask. Shape (batch_size, seq_length, seq_length)

    Raises
    ------
    NotImplementedError
        If `layout` or `attn_type` is not supported.
    """
    if layout == 'NT':
        batch_axis, time_axis = 0, 1
    elif layout == 'TN':
        batch_axis, time_axis = 1, 0
    else:
        raise NotImplementedError('Unsupported layout={}'.format(layout))
    if attn_type == 'full':
        if valid_length is not None:
            valid_length = valid_length.astype(dtype)
            steps = npx.arange_like(data, axis=time_axis)  # (seq_length,)
            # Valid key positions: (batch_size, 1, seq_length)
            mask1 = (npx.reshape(steps, (1, 1, -1))
                     < npx.reshape(valid_length, (-2, 1, 1)))
            # Valid query positions: (batch_size, seq_length, 1)
            mask2 = (npx.reshape(steps, (1, -1, 1))
                     < npx.reshape(valid_length, (-2, 1, 1)))
            mask = mask1 * mask2
        else:
            # TODO(sxjscience) optimize
            seq_len_ones = np.ones_like(npx.arange_like(data, axis=time_axis))  # (seq_length,)
            batch_ones = np.ones_like(npx.arange_like(data, axis=batch_axis))   # (batch_size,)
            mask = (batch_ones.reshape((-1, 1, 1))
                    * seq_len_ones.reshape((1, -1, 1))
                    * seq_len_ones.reshape((1, 1, -1)))
    elif attn_type == 'causal':
        steps = npx.arange_like(data, axis=time_axis)
        # mask: (seq_length, seq_length) with mask[i, j] = (j <= i)
        mask = (np.expand_dims(steps, axis=0) <= np.expand_dims(steps, axis=1)).astype(dtype)
        if valid_length is not None:
            valid_length = valid_length.astype(dtype)
            # batch_mask: (batch_size, seq_length), 1 for query positions < valid_length
            batch_mask = (np.expand_dims(steps, axis=0)
                          < np.expand_dims(valid_length, axis=-1)).astype(dtype)
            mask = mask * np.expand_dims(batch_mask, axis=-1)
        else:
            batch_ones = np.ones_like(npx.arange_like(data, axis=batch_axis),
                                      dtype=dtype)  # (batch_size,)
            mask = mask * batch_ones.reshape((-1, 1, 1))
    else:
        raise NotImplementedError('Unsupported attn_type={}'.format(attn_type))
    return mask.astype(np.bool)
Example #9
def multi_head_dot_attn(query, key, value,
                        mask=None,
                        edge_scores=None,
                        dropout: float = 0.0,
                        scaled: bool = True, normalized: bool = False,
                        eps: float = 1E-6, query_head_units: Optional[int] = None,
                        layout: str = 'NKT',
                        use_einsum: bool = False):
    """Multihead dot product attention between the query, key, value.

    scaled is False, normalized is False:
        D(h_q, h_k) = <h_q, h_k>
    scaled is True, normalized is False:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    scaled is False, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    scaled is True, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    If edge_scores is provided, we will calculate the attention as
        scores = D(h_q, h_k) + EdgeScore_{q, k}

    Parameters
    ----------
    query
        Query. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, query_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, query_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads, key_dim)
    key
        Key. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, key_dim)
    value
        Value. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, value_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, value_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, value_dim)
    mask
        Mask between query and memory. Shape (batch_size, query_length, mem_length).
        It gets a head axis inserted and is cast to bool; None is passed to
        masked_softmax unchanged.
    edge_scores
        The edge attention score. Shape can be any shape that is broadcastable to
        (batch_size, num_heads, query_length, mem_length)
    dropout
        Dropout rate applied to the attention weights
    scaled
        Whether to divide the attention weights by the sqrt of the query dimension.
        This is first proposed in "[NIPS2017] Attention is all you need."::

        .. code-block:: none

            score = <h_q, h_k> / sqrt(dim_q)
    normalized
        If turned on, the cosine distance is used, i.e::

        .. code-block:: none

            score = <h_q / ||h_q||, h_k / ||h_k||>
    eps
        The epsilon value used in L2 normalization
    query_head_units
        The units of each query head (dim_q). Required when `scaled` is True;
        a NotImplementedError is raised if it is None in that case.
    layout
        This stands for the layout of the attention cell. The shape of the input/output will depend
        on the layout. Currently, we support 'NKT', 'NTK' and 'TNK' in which
        'N' means the batch_size, 'K' means the head, and 'T' means the length dimension.
    use_einsum
        Whether to use einsum for the computation

    Returns
    -------
    context_vec
        - layout is 'NKT' or 'NTK'
            Shape (batch_size, query_length, num_heads * value_units)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads * value_units)
    additional_info
        [scores, attn_weights]

        scores
            The attention scores (including edge_scores), before the softmax.
            Shape (batch_size, num_head, query_length, mem_length)
        attn_weights
            The attention weights after the masked softmax and dropout.
            Shape (batch_size, num_head, query_length, mem_length)

    Raises
    ------
    NotImplementedError
        If `scaled` is True and `query_head_units` is None, or if `layout` is not
        one of 'NKT', 'NTK' and 'TNK'.
    """
    # TODO(sxjscience) Profile layout
    if normalized:
        query = l2_normalize(query, axis=-1, eps=eps)
        key = l2_normalize(key, axis=-1, eps=eps)
    if scaled:
        if query_head_units is None:
            raise NotImplementedError('You will need to specify query_head_units!')
        scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout not in ('NKT', 'NTK', 'TNK'):
        raise NotImplementedError('layout="{}" is not supported! '
                                  'We only support layout = "NKT", "NTK", and "TNK".'
                                  .format(layout))
    # 1. Expand the dimension of the mask so it broadcasts over the heads:
    #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
    if mask is not None:
        mask = np.expand_dims(mask, axis=1).astype(np.bool)
    # 2. Calculate the attention scores, always laid out as (B, N, L_query, L_mem)
    if layout == 'NKT':
        #   (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
    elif layout == 'NTK':
        #   (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum('binc,bjnc->bnij', query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
                                   transpose_b=True)
    else:  # 'TNK'
        #   (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        #   This layout structure can be implemented very efficiently because B, N are consecutive
        #   to each other. To have a clear picture of what's happening, we may consider the
        #   (i, j)th element of the output
        #       out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T, which is just one GEMM call
        #   We can thus implement the whole kernel via a single call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum('ibnc,jbnc->bnij', query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                   key.transpose((1, 2, 3, 0)))
    if edge_scores is not None:
        scores = scores + edge_scores
    attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
    attn_weights = npx.dropout(attn_weights, p=dropout)
    # 3. Calculate the context vector
    if layout == 'NKT':
        # (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N, C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bnjc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
    elif layout == 'NTK':
        # (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N, C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bjnc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
    else:  # 'TNK'
        # (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N, C_V)
        # Again, we can implement it via a single call to batched GEMM with stride.
        if use_einsum:
            context_vec = np.einsum('bnij,jbnc->ibnc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
    # Merge the head axis into the feature axis: (d0, d1, N * C_V)
    context_vec = npx.reshape(context_vec, (-2, -2, -1))
    return context_vec, [scores, attn_weights]
Example #10
def gen_mem_attn_mask(mem, mem_valid_length, data, data_valid_length=None,
                      dtype=np.float32, layout: str = 'NT'):
    """Generate the mask used for the decoder. All query slots are attended to the memory slots.

    In our implementation, 1 --> not masked, 0 --> masked

    Let's consider the data + mem with a batch of two samples:

    .. code-block:: none

        mem = [['I',   'can', 'now',   'use'],
               ['May', 'the', 'force', '<PAD>']]
        mem_valid_length =
            [4, 3]
        data =
            [['numpy', 'in',    'Gluon@@', 'NLP'  ],
             ['be',    'with',  'you',     '<PAD>']]
        data_valid_length =
            [4, 3]

    For our example, the mask of the first sample is

    .. code-block:: none

                   ['I', 'can', 'now', 'use']
        'numpy':     1,    1,     1,     1
        'in':        1,    1,     1,     1
        'Gluon@@':   1,    1,     1,     1
        'NLP':       1,    1,     1,     1

    The mask of the second sample is

    .. code-block:: none

                   ['May', 'the', 'force', '<PAD>']
        'be':         1,    1,     1,     0
        'with':       1,    1,     1,     0
        'you':        1,    1,     1,     0
        '<PAD>':      0,    0,     0,     0

    Parameters
    ----------
    mem
        - layout = 'NT'
            Shape (batch_size, mem_length, C_mem)
        - layout = 'TN'
            Shape (mem_length, batch_size, C_mem)
    mem_valid_length
        Shape (batch_size,)
    data
        - layout = 'NT'
            Shape (batch_size, query_length, C_data)
        - layout = 'TN'
            Shape (query_length, batch_size, C_data)
    data_valid_length
        Shape (batch_size,). If None, every query position is treated as valid.
    dtype
        Data type used for the intermediate computations. The returned mask is
        always cast to bool.
    layout
        Layout of the data + mem tensor

    Returns
    -------
    mask
        Boolean mask. Shape (batch_size, query_length, mem_length)

    Raises
    ------
    NotImplementedError
        If `layout` is not 'NT' or 'TN'.
    """
    if layout == 'NT':
        time_axis = 1
    elif layout == 'TN':
        time_axis = 0
    else:
        raise NotImplementedError('Unsupported layout={}'.format(layout))
    mem_valid_length = mem_valid_length.astype(dtype)
    mem_steps = npx.arange_like(mem, axis=time_axis)  # (mem_length,)
    data_steps = npx.arange_like(data, axis=time_axis)  # (query_length,)
    mem_mask = (npx.reshape(mem_steps, (1, 1, -1))
                < npx.reshape(mem_valid_length, (-2, 1, 1))).astype(dtype)  # (B, 1, mem_length)
    if data_valid_length is not None:
        data_valid_length = data_valid_length.astype(dtype)
        data_mask = (npx.reshape(data_steps, (1, -1, 1))
                     < npx.reshape(data_valid_length, (-2, 1, 1))).astype(dtype)  # (B, query_length, 1)
        mask = mem_mask * data_mask
    else:
        # Broadcast (B, 1, mem_length) over the query axis.
        query_length_ones = np.ones_like(data_steps)
        mask = query_length_ones.reshape((1, -1, 1)) * mem_mask
    return mask.astype(np.bool)
Example #11
 def transpose_for_scores(self, x):
     """Split the last (hidden) axis of ``x`` into ``(self._num_heads, -1)``.

     NT -> NTK: (B, L_seq, inner_dim) -> (B, L_seq, num_heads, n_kv)
     TN -> TNK: (L_seq, B, inner_dim) -> (L_seq, B, num_heads, n_kv)
     The first two axes are copied unchanged (-2) and the per-head size is inferred (-1).
     """
     head_shape = (-2, -2, self._num_heads, -1)
     return npx.reshape(x, head_shape)
    def forward(self, data, mem, rel_positions, mask, query_r_bias, query_k_bias):
        """Relative-position self-attention over [mem, data] followed by the FFN.

        Parameters
        ----------
        data
            The input data.
            layout = 'NT':
                Shape (batch_size, query_length, units)
            layout = 'TN':
                Shape (query_length, batch_size, units)
        mem
            The memory.
            layout = 'NT':
                Shape (batch_size, mem_length, units)
            layout = 'TN':
                Shape (mem_length, batch_size, units)
        rel_positions
            The relative positions between data and [mem, data]
            Shape (query_length, mem_length + query_length).
            A positive value means that query is after the memory, i.e.,
            query_location - mem_location.
        mask
            Mask between the query and the memory + query.
            1--> will be used, 0 --> won't be used
            Shape (batch_size, query_length, mem_length + query_length)
        query_r_bias
            The query bias for calculating the relative scores
            Shape (num_heads, query_head_units)
        query_k_bias
            The key bias for calculating the relative scores.
            Shape (num_heads, query_head_units)

        Returns
        -------
        out
            - layout = 'NT'
                Shape (batch_size, query_length, units)
            - layout = 'TN'
                Shape (query_length, batch_size, units)

        Raises
        ------
        NotImplementedError
            If ``self._layout`` is not 'NT' or 'TN'.
        """
        # Keys/values attend over the memory followed by the current data.
        if self._layout == 'NT':
            context = np.concatenate([mem, data], axis=1)
        elif self._layout == 'TN':
            context = np.concatenate([mem, data], axis=0)
        else:
            raise NotImplementedError
        if self._pre_norm:
            query = self.attn_query(self.layer_norm(data))
            key_value = self.attn_kv(self.layer_norm(context))
        else:
            query = self.attn_query(data)
            key_value = self.attn_kv(context)
        key, value = np.split(key_value, 2, axis=-1)
        # Split the hidden axis into heads: (d0, d1, num_heads, -1)
        query = npx.reshape(query, (-2, -2, self._num_heads, -1))
        key = npx.reshape(key, (-2, -2, self._num_heads, -1))
        value = npx.reshape(value, (-2, -2, self._num_heads, -1))
        # Compute attention
        rel_score = self.rel_pos_score_cell(rel_positions,
                                            query + query_r_bias)
        # NOTE(review): assumes attn_cell accepts the relative scores as `edge_scores`,
        # like multi_head_dot_attn -- confirm against the attention cell's signature.
        out, _ = self.attn_cell(query + query_k_bias, key, value, mask,
                                edge_scores=rel_score)
        out = self.dropout_layer(out)
        if self._pre_norm:
            out = data + out
        else:
            out = self.layer_norm(data + out)
        out = self.ffn(out)
        return out
 def forward(self, data, indices):
     """Keep the rows of ``data`` whose matching entry of ``indices`` is below 3.

     ``data`` is reshaped with ``(-1, -2), reverse=True`` (shape matched from the
     right: the last axis is kept, the leading axes are merged) and the
     ``indices < 3`` mask is flattened to 1-D before the boolean selection.
     """
     keep = indices < 3
     flat_data = npx.reshape(data, (-1, -2), reverse=True)
     flat_keep = np.reshape(keep, (-1, ))
     return nd.np._internal.boolean_mask(flat_data, flat_keep)
Example #14
    def forward(self, data, attn_mask):
        """Self-attention (optionally through a bottleneck) followed by the stacked FFNs.

        Parameters
        ----------
        data
            - layout = 'NT'
                Shape (batch_size, seq_length, C_in)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_in)
        attn_mask
            The attention mask
            Shape (batch_size, seq_length, seq_length)

        Returns
        -------
        out
            - layout = 'NT'
                Shape (batch_size, seq_length, C_out)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_out)
        attn_weight
            The attention weights returned by ``self.attention_cell``.

        Raises
        ------
        NotImplementedError
            If the bottleneck is enabled and ``self._bottleneck_strategy`` is not
            'qk_sharing', 'from_bottleneck' or 'from_input'.
        """
        if self._use_bottleneck:
            bn_proj = self.in_bottleneck_proj(data)
            bn_proj = self.in_bottleneck_ln(bn_proj)
            # The attention residual is taken from the bottleneck projection.
            residual = bn_proj
            if self._bottleneck_strategy == 'qk_sharing':
                # for Mobile Bert
                qk_shared = self.shared_qk(data)
                qk_shared = self.shared_qk_ln(qk_shared)
                query = qk_shared
                key = qk_shared
                value = data
            elif self._bottleneck_strategy == 'from_bottleneck':
                # for Mobile Bert Tiny
                query = bn_proj
                key = bn_proj
                value = bn_proj
            elif self._bottleneck_strategy == 'from_input':
                query = data
                key = data
                value = data
            else:
                raise NotImplementedError
        else:
            residual = data
            query = data
            key = data
            value = data

        # Project and split the hidden axis into heads: (d0, d1, num_heads, -1)
        query = npx.reshape(self.attn_query(query),
                            (-2, -2, self._num_heads, -1))
        key = npx.reshape(self.attn_key(key), (-2, -2, self._num_heads, -1))
        value = npx.reshape(self.attn_value(value),
                            (-2, -2, self._num_heads, -1))
        out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
        out = self.attention_proj(out)
        if not self._use_bottleneck:
            out = self.dropout_layer(out)
        out = out + residual
        out = self.layer_norm(out)
        for ffn_idx in range(self._num_stacked_ffn):
            ffn = self.stacked_ffn[ffn_idx]
            out = ffn(out)

        if self._use_bottleneck:
            # Project back out of the bottleneck and add the original input.
            out = self.out_bottleneck_proj(out)
            out = self.dropout_layer(out)
            out = out + data
            out = self.out_bottleneck_ln(out)
        return out, attn_weight
Example #15
 def transpose_for_scores(x):
     """Split the last axis of ``x`` into ``(self._num_heads, -1)``, keeping the first two axes.

     ``self`` is not a parameter, so it is resolved from an enclosing scope
     (presumably this is a closure defined inside a method -- confirm).
     """
     split_heads = (-2, -2, self._num_heads, -1)
     return npx.reshape(x, split_heads)