Example 1
def test_conversion(args, hf_model, gluon_model):
    logging.info('testing conversion...')
    # create dummy input
    batch_size = 6
    src_length = 128
    tgt_length = 8
    vocab_size = hf_model.shared.weight.shape[0]
    src_data = np.random.randint(1, vocab_size, (batch_size, src_length))
    src_valid_length = np.random.randint(src_length // 2, src_length,
                                         (batch_size, ))
    tgt_data = np.random.randint(1, vocab_size, (batch_size, tgt_length))
    tgt_valid_length = np.random.randint(tgt_length // 2, tgt_length,
                                         (batch_size, ))
    enc_attn_mask = npx.arange_like(src_data,
                                    axis=-1) < src_valid_length.reshape(-1, 1)
    dec_attn_mask = npx.arange_like(tgt_data,
                                    axis=-1) < tgt_valid_length.reshape(-1, 1)
    # test T5Model forward pass
    hf_model.eval()  # disable dropout
    hf_out = hf_model(
        input_ids=torch.from_numpy(src_data.asnumpy()),
        attention_mask=torch.from_numpy(enc_attn_mask.asnumpy()),
        decoder_input_ids=torch.from_numpy(tgt_data.asnumpy()),
        decoder_attention_mask=torch.from_numpy(
            dec_attn_mask.asnumpy()))['last_hidden_state'].detach().numpy()
    gl_out = gluon_model(src_data, src_valid_length, tgt_data,
                         tgt_valid_length)
    for i in range(batch_size):
        assert np.allclose(hf_out[i, :tgt_valid_length[i].item(), :],
                           gl_out[i, :tgt_valid_length[i].item(), :], 1E-3,
                           1E-3)
    logging.info('pass')
Example 2
    def _get_relative_position(self, hidden_states):
        query_position = np.expand_dims(
            npx.arange_like(hidden_states, axis=self.time_axis), axis=-1)
        mem_position = np.expand_dims(
            npx.arange_like(hidden_states, axis=self.time_axis), axis=0)
        relative_position = mem_position - query_position
        return relative_position.astype(np.int32)
Example 3
    def forward(self, x, layer_states):
        """

        Parameters
        ----------
        x
            - layout = 'NT'
                Shape (batch_size, seq_length, C_in)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_in)

        layer_states
            - layout = 'NT'
                Shape (2, batch_size, prev_len, C_in)
            - layout = 'TN'
                Shape (2, prev_len, batch_size, C_in)
        """
        x = self.ln(x)
        if self._layout == 'NT':
            batch_axis, time_axis = 0, 1
            prev_len = npx.shape_array(layer_states)[2]
        else:
            batch_axis, time_axis = 1, 0
            prev_len = npx.shape_array(layer_states)[1]

        query, key, value = np.split(self.qkv(x), 3, axis=-1)
        if layer_states is not None:
            prev_key, prev_value = layer_states[0], layer_states[1]
            key = np.concatenate([prev_key, key], axis=time_axis)
            value = np.concatenate([prev_value, value], axis=time_axis)
        new_states = np.stack([key, value], axis=0)

        # gen mask
        query_pos = npx.arange_like(query, axis=time_axis)
        if prev_len is not None:
            query_pos = query_pos + prev_len
        key_pos = npx.arange_like(key, axis=time_axis)
        # (query_len, key_len)
        mask = (npx.reshape(key_pos,
                            (1, -1)) <= npx.reshape(query_pos,
                                                    (-1, 1))).astype(
                                                        self._dtype)
        # broadcast to (batch_size, query_len, key_len)
        mask = npx.broadcast_like(np.expand_dims(mask, axis=0),
                                  query,
                                  lhs_axes=0,
                                  rhs_axes=batch_axis)

        query = npx.reshape(query, (-2, -2, self._num_heads, -1))
        key = npx.reshape(key, (-2, -2, self._num_heads, -1))
        value = npx.reshape(value, (-2, -2, self._num_heads, -1))

        out, [_, attn_weight] = self.attention_cell(query, key, value, mask)
        out = self.out_proj(out)
        out = self.hidden_dropout(out)

        return out, new_states
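For orientation, here is a standalone sketch of the causal mask built above. It is only a sketch, assuming mxnet.numpy as np, mxnet.numpy_extension as npx, layout 'NT', a hypothetical prev_len of 2 cached steps and a current query of length 3:

# Hypothetical toy shapes for the incremental causal mask above.
query = np.zeros((1, 3, 4))                      # (batch_size, query_len, C)
key = np.zeros((1, 5, 4))                        # (batch_size, prev_len + query_len, C)
query_pos = npx.arange_like(query, axis=1) + 2   # offset by prev_len -> [2, 3, 4]
key_pos = npx.arange_like(key, axis=1)           # [0, 1, 2, 3, 4]
mask = npx.reshape(key_pos, (1, -1)) <= npx.reshape(query_pos, (-1, 1))
# mask.shape == (3, 5); step i may attend to all cached keys plus itself:
# [[1, 1, 1, 0, 0],
#  [1, 1, 1, 1, 0],
#  [1, 1, 1, 1, 1]]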
Example 4
def gen_rel_position(data, past_data=None, dtype=np.int32, layout='NT'): 
    """Create a matrix of relative position for RelAttentionScoreCell. 
    
    The relative position is defined as the index difference: `mem_i` - `query_j`. 
    Note, though, that the implementation here makes sense in self-attention's setting, 
    but not in cross-attention's. Hence, both `mem_i` and `query_j` are time indices from 
    `data` (or, in incremental decoding's case, the concatenated sequence from the current 
    stepwise `data` and the previous steps `past_data`). 

    Parameters
    ----------
    data
        The data. Under incremental decoding, seq_length = 1. 

        - layout = 'NT'
            Shape (batch_size, seq_length, C)
        - layout = 'TN'
            Shape (seq_length, batch_size, C)
    past_data
        This is only used under incremental decoding. Stacked data from previous steps. 
    dtype
        Data type of the returned relative position matrix
    layout
        Layout of the data + past_data

    Returns
    -------
    relative_position :
        Shape (query_length, mem_length) where query_length = mem_length = seq_length
    """
    time_axis = 1 if layout == 'NT' else 0
    if past_data is None: 
        position = npx.arange_like(data, axis=time_axis)
    else: 
        # for incremental decoding only, where past data is of the shape: 
        # NT(NTK): (B, L_seq, num_heads, n_kv) -> (B, L_seq, inner_dim)
        # TN(TNK): (L_seq, B, num_heads, n_kv) -> (L_seq, B, inner_dim)
        past_data = npx.reshape(past_data, (-2, -2, -5))
        position = npx.arange_like(
            np.concatenate([past_data, data], axis=time_axis), 
            axis=time_axis
        )
    query_position = np.expand_dims(position, axis=-1)
    mem_position = np.expand_dims(position, axis=0)
    relative_position = mem_position - query_position
    return relative_position.astype(dtype)  # shape (L_seq, L_seq)
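A minimal usage sketch (hypothetical toy input; assumes the same mxnet.numpy as np / mxnet.numpy_extension as npx imports this helper relies on):

data = np.zeros((2, 3, 8))          # layout 'NT': (batch_size=2, seq_length=3, C=8)
rel_pos = gen_rel_position(data)    # shape (3, 3), dtype int32
# rel_pos[query_j, mem_i] == mem_i - query_j, i.e.
# [[ 0,  1,  2],
#  [-1,  0,  1],
#  [-2, -1,  0]]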
Example 5
def test_arange_like():
    A = np.zeros((INT_OVERFLOW, 2))
    A.attach_grad()
    with mx.autograd.record():
        B = npx.arange_like(A)
    assert B.shape == (INT_OVERFLOW, 2)
    assert B[100][0] == 200
    B.backward()
    assert A.grad.shape == (INT_OVERFLOW, 2)
    assert A.grad[0][0] == 0
Example 6
    def forward(self, x: np.ndarray) -> np.ndarray:
        # Shape: (length,)
        length_array = npx.arange_like(x, axis=1)
        # matrix with lower triangle and main diagonal set to 0, upper triangle set to 1
        # Shape: (length, length)
        bias = npx.broadcast_greater(
            np.expand_dims(length_array, axis=0),
            np.expand_dims(length_array, axis=1))
        bias = bias * -C.LARGE_VALUES[self._dtype]
        bias = np.expand_dims(bias, axis=0)
        return npx.stop_gradient(bias)
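The same upper-triangular bias can be reproduced outside the class. A minimal sketch, assuming mxnet.numpy as np, mxnet.numpy_extension as npx, and -1e9 standing in for C.LARGE_VALUES[self._dtype]:

x = np.zeros((2, 4, 8))                     # (batch, length, model_dim)
length_array = npx.arange_like(x, axis=1)   # shape (length,)
bias = npx.broadcast_greater(
    np.expand_dims(length_array, axis=0),   # shape (1, length)
    np.expand_dims(length_array, axis=1))   # shape (length, 1)
bias = np.expand_dims(bias * -1e9, axis=0)  # shape (1, length, length)
# bias[0, i, j] == -1e9 for j > i (future positions) and 0 elsewhere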
Example 7
    def _get_relative_position(self, hidden_states, mem_states=None,
                               past_key_value=None):
        if past_key_value is None:
            query_position = np.expand_dims(
                npx.arange_like(hidden_states, axis=self.time_axis), axis=-1)
        else:
            # for incremental decoding only, where past key and past value are of shape
            # NT(NTK): (B, L_seq, num_heads, n_kv); TN(TNK): (L_seq, B, num_heads, n_kv)
            query_position = npx.arange_like(
                np.concatenate([hidden_states, past_key_value[0]],
                               axis=self.time_axis),
                axis=self.time_axis)
            query_position = np.expand_dims(query_position, axis=-1)
        mem_position = np.expand_dims(
            npx.arange_like(hidden_states if mem_states is None else mem_states,
                            axis=self.time_axis),
            axis=0)
        relative_position = mem_position - query_position
        return relative_position.astype(np.int32)
Example 8
    def init_state_from_encoder(
            self,
            encoder_outputs: np.ndarray,
            encoder_valid_length: Optional[np.ndarray] = None,
            target_embed: Optional[np.ndarray] = None) -> List[np.ndarray]:
        """
        Returns the initial states given encoder output. States for teacher-forced training are encoder outputs
        and a valid length mask for encoder outputs.
        At inference, this method returns the following state tuple:
        valid length bias, step state,
        [projected encoder attention keys, projected encoder attention values] * num_layers,
        [autoregressive state dummies] * num_layers.

        :param encoder_outputs: Encoder outputs. Shape: (batch, source_length, encoder_dim).
        :param encoder_valid_length: Valid lengths of encoder outputs. Shape: (batch,).
        :param target_embed: Target-side embedding layer output. Shape: (batch, target_length, target_embedding_dim).
        :return: Initial states.
        """
        if target_embed is None:  # Inference: initial step = 0. Shape: (batch_size, 1)
            steps = np.expand_dims(np.zeros_like(encoder_valid_length), axis=1)
        else:  # Training: steps up to target length. Shape: (1, target_length)
            steps = np.expand_dims(npx.arange_like(target_embed, axis=1),
                                   axis=0)

        if self.inference_only:
            # Encoder projection caching, therefore we don't pass the encoder_outputs
            states = [steps, encoder_valid_length]

            for layer in self.layers:
                enc_att_kv = layer.enc_attention.ff_kv(encoder_outputs)
                states.append(np.transpose(enc_att_kv, axes=(1, 0, 2)))
        else:
            # NO encoder projection caching
            states = [
                steps,
                np.transpose(encoder_outputs, axes=(1, 0, 2)),
                encoder_valid_length
            ]

        _batch_size = encoder_outputs.shape[0]
        _ctx = encoder_outputs.ctx
        _dtype = encoder_outputs.dtype
        dummy_autoregr_states = [
            np.zeros(layer.get_states_shape(_batch_size),
                     ctx=_ctx,
                     dtype=_dtype) for layer in self.layers
            for _ in range(layer.num_state_tensors)
        ]

        states += dummy_autoregr_states
        return states
Example 9
    def get_initial_embedding(self, inputs, token_types=None):
        """Get the initial token embeddings that considers the token type and positional embeddings

        Parameters
        ----------
        inputs
            - layout = 'NT'
                Shape (batch_size, seq_length)
            - layout = 'TN'
                Shape (seq_length, batch_size)

        token_types
            The type of tokens. If None, it will be initialized as all zero.

            - layout = 'NT'
                Shape (batch_size, seq_length)
            - layout = 'TN'
                Shape (seq_length, batch_size)

        Returns
        -------
        embedding
            The initial embedding that will be fed into the encoder

            - layout = 'NT'
                Shape (batch_size, seq_length, C_embed)
            - layout = 'TN'
                Shape (seq_length, batch_size, C_embed)

        """
        if self.layout == 'NT':
            time_axis, batch_axis = 1, 0
        else:
            time_axis, batch_axis = 0, 1
        embedding = self.word_embed(inputs)
        if token_types is None:
            token_types = np.zeros_like(inputs)
        type_embedding = self.token_type_embed(token_types)
        embedding = embedding + type_embedding
        if self.pos_embed_type is not None:
            positional_embedding = self.token_pos_embed(npx.arange_like(inputs, axis=time_axis))
            positional_embedding = np.expand_dims(positional_embedding, axis=batch_axis)
            embedding = embedding + positional_embedding
        # Extra layer normalization plus dropout
        embedding = self.embed_layer_norm(embedding)
        embedding = self.embed_dropout(embedding)
        return embedding
Example 10
def add_vectors_by_position(data, increment, positions):
    """Scatter each batch with the given positions.

    data[i, positions[i, j], ...] += increment[i, j, ...]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length, ...)
    increment
        Input tensor of the values to add
        Shape (batch_size, num_disp_position, ...)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length, ...)
    """
    # Here, we use index_add to scatter the increments into data:
    # Need to compute
    #   out[i, masked_position[i, j], :] += increment[i, j, :]
    # Thus, construct an indices array with shape [2, batch_size * num_masked_position], where
    #     indices[0, i * num_masked_position + j] = i
    #     indices[1, i * num_masked_position + j] = masked_position[i, j]
    # and reshape increment to (batch_size * num_masked_position, ...)
    # Then, out = npx.index_add(data, indices, increment)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx.reshape((-1, )), positions.reshape((-1, ))])
    out = npx.index_add(data, indices, npx.reshape(increment, (-5, -4)))
    return out
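A small worked example may help; this is only a sketch with hypothetical toy inputs, assuming the mxnet.numpy (np) namespace used above:

data = np.array([[0., 0., 0., 0.],
                 [5., 5., 5., 5.]])    # (batch_size=2, seq_length=4)
increment = np.array([[10., 20.],
                      [10., 20.]])     # (batch_size=2, num_disp_position=2)
positions = np.array([[1, 3],
                      [0, 2]])         # (batch_size=2, num_disp_position=2)
out = add_vectors_by_position(data, increment, positions)
# out[0] == [ 0., 10.,  0., 20.]
# out[1] == [15.,  5., 25.,  5.]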
Example 11
def select_vectors_by_position(data, positions):
    """Select each batch with the given positions.

    Once advanced indexing can be hybridized, we can revise the implementation.

    out[i, j, ...] = data[i, positions[i, j], ...]

    Parameters
    ----------
    data
        Input tensor of contextualized token embeddings
        Shape (batch_size, seq_length, ...)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_sel_positions).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The selection result.
        Shape (batch_size, num_sel_positions, ...)
    """
    # Here, we use gather_nd to select the output from data:
    # Need to compute
    #   out[i, j, :] = in[i, masked_position[i, j], :]
    # Thus, construct an indices array with shape [2, batch_size, num_masked_position], where
    #     indices[0, i, j] = i
    #     indices[1, i, j] = masked_position[i, j]
    # Then, out = gather_nd(in, indices)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx, positions])
    # TODO(sxjscience) We can revise the implementation to advanced indexing
    #  once the bug in MXNet is solved:
    #  https://github.com/apache/incubator-mxnet/issues/18919
    out = npx.gather_nd(data, indices)
    return out
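A small worked example (again a sketch with hypothetical toy inputs, same np namespace as above):

data = np.array([[[0.], [1.], [2.], [3.]],
                 [[4.], [5.], [6.], [7.]]])   # (batch_size=2, seq_length=4, 1)
positions = np.array([[2, 0],
                      [1, 3]])                # (batch_size=2, num_sel_positions=2)
out = select_vectors_by_position(data, positions)
# out[0] == [[2.], [0.]]
# out[1] == [[5.], [7.]]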
Example 12
def update_vectors_by_position(data, val, positions):
    """
    Update each batch with the given positions. Considered as a reversed process of
    "select_vectors_by_position", this is an operator similar to "add_vectors_by_position"
    that updates the results instead of adding.

    data[i, positions[i, j], :] = val[i, j, :]

    Parameters
    ----------
    data
        Input tensor of the array to be updated.
        Shape (batch_size, seq_length)
    val
        Input tensor of token ids
        Shape (batch_size, num_disp_position)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_disp_position).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The updated result.
        Shape (batch_size, seq_length)
    """
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = np.expand_dims(npx.arange_like(positions, axis=0),
                               axis=1).astype(np.int32)
    batch_idx = batch_idx + np.zeros_like(positions)
    indices = np.stack([batch_idx.reshape((-1, )), positions.reshape((-1, ))])

    out = npx.index_update(data, indices, npx.reshape(val, (-5, -4)))
    return out
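And a small worked example for the update variant (a sketch with hypothetical toy inputs):

data = np.array([[11, 12, 13, 14],
                 [21, 22, 23, 24]], dtype=np.int32)   # (batch_size=2, seq_length=4)
val = np.array([[100, 200],
                [300, 400]], dtype=np.int32)          # (batch_size=2, num_disp_position=2)
positions = np.array([[1, 3],
                      [0, 2]])
out = update_vectors_by_position(data, val, positions)
# out[0] == [ 11, 100,  13, 200]
# out[1] == [300,  22, 400,  24]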
Example 13
    def get_initial_embedding(self, inputs):
        """Get the initial token embeddings that considers the token type and positional embeddings

        Parameters
        ----------
        inputs
            - layout = 'NT'
                Shape (batch_size, seq_length)
            - layout = 'TN'
                Shape (seq_length, batch_size)

        Returns
        -------
        embedding
            The initial embedding that will be fed into the encoder

            - layout = 'NT'
                Shape (batch_size, seq_length, C)
            - layout = 'TN'
                Shape (seq_length, batch_size, C)

        """
        if self._layout == 'NT':
            batch_axis, time_axis = 0, 1
        else:
            batch_axis, time_axis = 1, 0
        embedding = self.word_embed(inputs)
        if self.pos_embed_type:
            positional_embedding = self.pos_embed(
                npx.arange_like(inputs, axis=time_axis))
            positional_embedding = np.expand_dims(positional_embedding,
                                                  axis=batch_axis)
            embedding = embedding + positional_embedding
        if self.encoder_normalize_before:
            embedding = self.embed_ln(embedding)
        embedding = self.embed_dropout(embedding)

        return embedding
Example 14
    def get_initial_embedding(self, inputs, prev_len):
        """Get the initial token embeddings that considers the token type and positional embeddings

        Parameters
        ----------
        inputs
            - layout = 'NT'
                Shape (batch_size, seq_length)
            - layout = 'TN'
                Shape (seq_length, batch_size)

        prev_len
            The previous length. It will be a scalar.

        Returns
        -------
        embedding
            - layout = 'NT'
                Shape (batch_size, seq_length, C)
            - layout = 'TN'
                Shape (seq_length, batch_size, C)

        """
        embedding = self._embed(inputs)
        if self._layout == 'NT':
            batch_axis, time_axis = 0, 1
        else:
            batch_axis, time_axis = 1, 0
        if self._pos_embed_type is not None:
            pos = npx.arange_like(inputs, axis=time_axis)
            if prev_len is not None:
                pos = pos + prev_len
            positional_embedding = self._pos_embed(pos)
            positional_embedding = np.expand_dims(positional_embedding,
                                                  axis=batch_axis)
            embedding = embedding + positional_embedding
        embedding = self._embed_dropout(embedding)
        return embedding
Example 15
    def get_initial_embedding(self, inputs, token_types=None):
        """Get the initial token embeddings that considers the token type and positional embeddings

        Parameters
        ----------
        inputs
            - layout = 'NT'
                Shape (batch_size, seq_length)
            - layout = 'TN'
                Shape (seq_length, batch_size)
        token_types
            - layout = 'NT'
                Shape (batch_size, seq_length)
            - layout = 'TN'
                Shape (seq_length, batch_size)
            If None, it will be initialized as all zeros

        Returns
        -------
        embedding
            The initial embedding that will be fed into the encoder
        """
        if self._layout == 'NT':
            batch_axis, time_axis = 0, 1
        elif self._layout == 'TN':
            batch_axis, time_axis = 1, 0
        else:
            raise NotImplementedError
        word_embedding = self.word_embed(inputs)

        if self.trigram_embed:
            if self._layout == 'NT':
                word_embedding = np.concatenate([
                    np.pad(word_embedding[:, 1:],
                           ((0, 0), (0, 1), (0, 0))), word_embedding,
                    np.pad(word_embedding[:, :-1], ((0, 0), (1, 0), (0, 0)))
                ],
                                                axis=-1)
            elif self._layout == 'TN':
                word_embedding = np.concatenate([
                    np.pad(word_embedding[1:, :],
                           ((0, 1), (0, 0), (0, 0))), word_embedding,
                    np.pad(word_embedding[:-1, :], ((1, 0), (0, 0), (0, 0)))
                ],
                                                axis=-1)
            else:
                raise NotImplementedError
        # Projecting the embedding into units only for word embedding
        if self.trigram_embed or self.embed_size != self.units:
            word_embedding = self.embed_factorized_proj(word_embedding)

        if token_types is None:
            token_types = np.zeros_like(inputs)
        type_embedding = self.token_type_embed(token_types)
        embedding = word_embedding + type_embedding
        if self.pos_embed_type is not None:
            positional_embedding =\
                self.token_pos_embed(npx.arange_like(embedding, axis=time_axis))
            positional_embedding = np.expand_dims(positional_embedding,
                                                  axis=batch_axis)
            embedding = embedding + positional_embedding
        # Extra layer normalization plus dropout
        embedding = self.embed_layer_norm(embedding)
        embedding = self.embed_dropout(embedding)
        return embedding
Example 16
    def forward(self, rel_positions, query=None):
        """Forward function

        Parameters
        ----------
        rel_positions
            The relative shifts. Shape (query_length, mem_length).
            Each element represents the shift between the :math:`i-th` element of query and
            the :math:`j-th` element of memory.
        query
            The query for computing the relative scores. The shape depends on the layout.
            If we use T5 attention, the query will not be used.

        Returns
        -------
        rel_scores
            The relative attention scores
            Can have shape (batch_size, num_heads, query_length, mem_length)
            or (num_heads, query_length, mem_length)
        """
        if self._method == 'transformer_xl' or self._method == 'shaw':
            assert query is not None, 'Must specify query if method={}'.format(self._method)
            if self._bidirectional:
                if self._max_distance is not None:
                    rel_positions = np.clip(rel_positions,
                                              a_min=-self._max_distance, a_max=self._max_distance)
            else:
                if self._max_distance is not None:
                    rel_positions = np.clip(rel_positions,
                                              a_min=0, a_max=self._max_distance)
            # uniq_rel.shape = (#uniq,), rev_index.shape = (L_q, L_m)
            uniq_rel, rev_index = np.unique(rel_positions, return_inverse=True)

            uniq_rel_pos_embed = self._rel_pos_embed(uniq_rel)
            if self._method == 'transformer_xl':
                uniq_rel_pos_embed = self._rel_proj(self._dropout_layer(uniq_rel_pos_embed))
            # Shape (#uniq, K, C_q)
            uniq_rel_pos_embed = npx.reshape(uniq_rel_pos_embed,
                                               (-2, self._num_heads, self._head_query_units))
            # Calculate the dot-product between query and the relative positional embeddings.
            # After the calculation, rel_score.shape = (L_q, #uniq, N, K)
            if self._layout == 'NKT':
                # query_for_rel: (N, K, L_q, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('bnid,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(query,
                                    np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            elif self._layout == 'NTK':
                # query_for_rel: (N, L_q, K, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('bind,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(np.swapaxes(query, 1, 2),
                                    np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            elif self._layout == 'TNK':
                # query_for_rel: (L_q, N, K, C_q)
                if self._use_einsum:
                    rel_score = np.einsum('ibnd,jnd->ijbn', query, uniq_rel_pos_embed)
                else:
                    rel_score = np.transpose(
                        np.matmul(np.transpose(query, (1, 2, 0, 3)),
                                    np.transpose(uniq_rel_pos_embed, (1, 2, 0))),
                        (2, 3, 0, 1)
                    )
            else:
                raise NotImplementedError
            # We use gather_nd to select the elements
            # TODO(sxjscience) Use advanced indexing once available
            rev_index = npx.reshape_like(rev_index, rel_positions).astype(np.int32)
            query_idx = np.expand_dims(npx.arange_like(rel_positions, axis=0).astype(np.int32),
                                         axis=-1) + np.zeros_like(rev_index)
            rel_score = npx.gather_nd(rel_score, np.stack([query_idx, rev_index]))
            rel_score = np.transpose(rel_score, (2, 3, 0, 1))
        elif self._method == 't5':
            # shape is (K, L_q, L_m)
            rel_score = self._rel_pos_embed(rel_positions).transpose((2, 0, 1))
        else:
            raise NotImplementedError
        return rel_score
Example 17
def gen_self_attn_mask(data,
                       valid_length=None,
                       dtype: type = np.float32,
                       attn_type: str = 'full',
                       layout: str = 'NT'):
    """Generate the mask used for the encoder, i.e, self-attention.

    In our implementation, 1 --> not masked, 0 --> masked

    Let's consider the data with two samples:

    .. code-block:: none

        data =
            [['I',   'can', 'now',   'use', 'numpy', 'in',  'Gluon@@', 'NLP'  ],
             ['May', 'the', 'force', 'be',  'with',  'you', '<PAD>',   '<PAD>']]
        valid_length =
            [8, 6]

    - attn_type = 'causal'
        Each token will attend to itself + the tokens before.
        It will not attend to tokens in the future.

        For our example, the mask of the first sample is

        .. code-block:: none

                       ['I', 'can', 'now', 'use', 'numpy', 'in', 'Gluon@@', 'NLP']
            'I':         1,    0,     0,     0,      0,     0,      0,      0
            'can':       1,    1,     0,     0,      0,     0,      0,      0
            'now':       1,    1,     1,     0,      0,     0,      0,      0
            'use':       1,    1,     1,     1,      0,     0,      0,      0
            'numpy':     1,    1,     1,     1,      1,     0,      0,      0
            'in':        1,    1,     1,     1,      1,     1,      0,      0
            'Gluon@@':   1,    1,     1,     1,      1,     1,      1,      0
            'NLP':       1,    1,     1,     1,      1,     1,      1,      1

        The mask of the second sample is

        .. code-block:: none

                       ['May', 'the', 'force', 'be', 'with', 'you', '<PAD>', '<PAD>']
            'May':        1,    0,     0,     0,      0,     0,      0,      0
            'the':        1,    1,     0,     0,      0,     0,      0,      0
            'force':      1,    1,     1,     0,      0,     0,      0,      0
            'be':         1,    1,     1,     1,      0,     0,      0,      0
            'with':       1,    1,     1,     1,      1,     0,      0,      0
            'you':        1,    1,     1,     1,      1,     1,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0


    - attn_type = 'full'
        Each token will attend to both the tokens before and in the future

        For our example, the mask of the first sample is

        .. code-block:: none

                       ['I', 'can', 'now', 'use', 'numpy', 'in', 'Gluon@@', 'NLP']
            'I':         1,    1,     1,     1,      1,     1,      1,      1
            'can':       1,    1,     1,     1,      1,     1,      1,      1
            'now':       1,    1,     1,     1,      1,     1,      1,      1
            'use':       1,    1,     1,     1,      1,     1,      1,      1
            'numpy':     1,    1,     1,     1,      1,     1,      1,      1
            'in':        1,    1,     1,     1,      1,     1,      1,      1
            'Gluon@@':   1,    1,     1,     1,      1,     1,      1,      1
            'NLP':       1,    1,     1,     1,      1,     1,      1,      1

        The mask of the second sample is

        .. code-block:: none

                       ['May', 'the', 'force', 'be', 'with', 'you', '<PAD>', '<PAD>']
            'May':        1,    1,     1,     1,      1,     1,      0,      0
            'the':        1,    1,     1,     1,      1,     1,      0,      0
            'force':      1,    1,     1,     1,      1,     1,      0,      0
            'be':         1,    1,     1,     1,      1,     1,      0,      0
            'with':       1,    1,     1,     1,      1,     1,      0,      0
            'you':        1,    1,     1,     1,      1,     1,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0
            '<PAD>':      0,    0,     0,     0,      0,     0,      0,      0

    Parameters
    ----------
    data
        The data.

        - layout = 'NT'
            Shape (batch_size, seq_length, C)
        - layout = 'TN'
            Shape (seq_length, batch_size, C)

    valid_length
        Shape (batch_size,)
    dtype
        Data type of the mask
    attn_type
        Can be 'full' or 'causal'
    layout
        The layout of the data

    Returns
    -------
    mask
        Shape (batch_size, seq_length, seq_length)
    """
    if layout == 'NT':
        batch_axis, time_axis = 0, 1
    elif layout == 'TN':
        batch_axis, time_axis = 1, 0
    else:
        raise NotImplementedError('Unsupported layout={}'.format(layout))
    if attn_type == 'full':
        if valid_length is not None:
            valid_length = valid_length.astype(dtype)
            steps = npx.arange_like(data, axis=time_axis)  # (seq_length,)
            mask1 = (npx.reshape(steps, (1, 1, -1))
                     < npx.reshape(valid_length, (-2, 1, 1)))
            mask2 = (npx.reshape(steps, (1, -1, 1))
                     < npx.reshape(valid_length, (-2, 1, 1)))
            mask = mask1 * mask2
        else:
            # TODO(sxjscience) optimize
            seq_len_ones = np.ones_like(npx.arange_like(data, axis=time_axis))  # (seq_length,)
            batch_ones = np.ones_like(npx.arange_like(data, axis=batch_axis))   # (batch_size,)
            mask = batch_ones.reshape((-1, 1, 1)) * seq_len_ones.reshape((1, -1, 1))\
                   * seq_len_ones.reshape((1, 1, -1))
    elif attn_type == 'causal':
        steps = npx.arange_like(data, axis=time_axis)
        # mask: (seq_length, seq_length)
        # batch_mask: (batch_size, seq_length)
        mask = (np.expand_dims(steps, axis=0) <= np.expand_dims(steps, axis=1)).astype(dtype)
        if valid_length is not None:
            valid_length = valid_length.astype(dtype)
            batch_mask = (np.expand_dims(steps, axis=0) < np.expand_dims(valid_length, axis=-1)).astype(dtype)
            mask = mask * np.expand_dims(batch_mask, axis=-1)
        else:
            batch_ones = np.ones_like(npx.arange_like(data, axis=batch_axis),
                                        dtype=dtype)  # (batch_size,)
            mask = mask * batch_ones.reshape((-1, 1, 1))
    else:
        raise NotImplementedError
    return mask.astype(np.bool)
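A usage sketch for the two-sample setting in the docstring (hypothetical embedded inputs; assumes mxnet.numpy as np):

data = np.zeros((2, 8, 4))       # (batch_size=2, seq_length=8, C=4)
valid_length = np.array([8, 6])
causal_mask = gen_self_attn_mask(
    data, valid_length, attn_type='causal', layout='NT')
full_mask = gen_self_attn_mask(
    data, valid_length, attn_type='full', layout='NT')
# Both masks have shape (2, 8, 8). For the second sample,
# causal_mask[1, 5] == [True, True, True, True, True, True, False, False]
# and full_mask[1, 7] is all False because the last two tokens are padding.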
Example 18
def gen_mem_attn_mask(mem, mem_valid_length, data, data_valid_length=None,
                      dtype=np.float32, layout: str = 'NT'):
    """Generate the mask used for the decoder. All query slots are attended to the memory slots.

    In our implementation, 1 --> not masked, 0 --> masked

    Let's consider the data + mem with a batch of two samples:

    .. code-block:: none

        mem = [['I',   'can', 'now',   'use'],
               ['May', 'the', 'force', '<PAD>']]
        mem_valid_length =
            [4, 3]
        data =
            [['numpy', 'in',    'Gluon@@', 'NLP'  ],
             ['be',    'with',  'you',     '<PAD>']]
        data_valid_length =
            [4, 3]

    For our example, the mask of the first sample is

    .. code-block:: none

                   ['I', 'can', 'now', 'use']
        'numpy':     1,    1,     1,     1
        'in':        1,    1,     1,     1
        'Gluon@@':   1,    1,     1,     1
        'NLP':       1,    1,     1,     1

    The mask of the second sample is

    .. code-block:: none

                   ['May', 'the', 'force', '<PAD>']
        'be':         1,     1,      1,       0
        'with':       1,     1,      1,       0
        'you':        1,     1,      1,       0
        '<PAD>':      0,     0,      0,       0


    Parameters
    ----------
    mem
       - layout = 'NT'
            Shape (batch_size, mem_length, C_mem)
       - layout = 'TN'
            Shape (mem_length, batch_size, C_mem)

    mem_valid_length :
        Shape (batch_size,)
    data
        - layout = 'NT'
            Shape (batch_size, query_length, C_data)
        - layout = 'TN'
            Shape (query_length, batch_size, C_data)

    data_valid_length :
        Shape (batch_size,)
    dtype
        Data type of the mask
    layout
        Layout of the data + mem tensor

    Returns
    -------
    mask :
        Shape (batch_size, query_length, mem_length)
    """
    if layout == 'NT':
        batch_axis, time_axis = 0, 1
    elif layout == 'TN':
        batch_axis, time_axis = 1, 0
    else:
        raise NotImplementedError('Unsupported layout={}'.format(layout))
    mem_valid_length = mem_valid_length.astype(dtype)
    mem_steps = npx.arange_like(mem, axis=time_axis)  # (mem_length,)
    data_steps = npx.arange_like(data, axis=time_axis)  # (query_length,)
    mem_mask = (npx.reshape(mem_steps, (1, 1, -1))
                < npx.reshape(mem_valid_length, (-2, 1, 1))).astype(dtype)  # (B, 1, mem_length)
    if data_valid_length is not None:
        data_valid_length = data_valid_length.astype(dtype)
        data_mask = (npx.reshape(data_steps, (1, -1, 1))
                     < npx.reshape(data_valid_length, (-2, 1, 1))).astype(dtype)  # (B, query_length, 1)
        mask = mem_mask * data_mask
    else:
        query_length_ones = np.ones_like(data_steps)
        mask = query_length_ones.reshape((1, -1, 1)) * mem_mask
    return mask.astype(np.bool)
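A usage sketch for the two-sample setting in the docstring (hypothetical embedded inputs; assumes mxnet.numpy as np):

mem = np.zeros((2, 4, 8))          # (batch_size=2, mem_length=4, C_mem=8)
data = np.zeros((2, 4, 8))         # (batch_size=2, query_length=4, C_data=8)
mem_valid_length = np.array([4, 3])
data_valid_length = np.array([4, 3])
mask = gen_mem_attn_mask(
    mem, mem_valid_length, data, data_valid_length, layout='NT')
# mask.shape == (2, 4, 4). For the second sample, mask[1, 0] ==
# [True, True, True, False]: the first query token attends to the three valid
# memory slots, while mask[1, 3] is all False (padded query position).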