Example #1
    def forward(self, data: np.ndarray,
                prev: Optional[np.ndarray]) -> np.ndarray:
        """
        Apply processing sequence to data with optional previous input.

        :param data: Input data. Shape: (batch, length, num_hidden).
        :param prev: Previous data. Shape: (batch, length, num_hidden).
        :return: Processed data. Shape: (batch, length, num_hidden).
        """
        if not self.sequence:
            return data

        if prev is None:
            assert 'r' not in self.sequence, "Residual connection not allowed if no previous value given."

        for step in self.sequence:

            if step == "r":
                data = data + prev

            elif step == "n":
                data = self.layer_norm(data)

            elif step == "d":
                if self.dropout > 0.0:
                    data = npx.dropout(data, p=self.dropout)
            else:
                raise ValueError("Unknown step in sequence: %s" % step)

        return data
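
For reference, the following is a minimal plain-NumPy sketch of the same 'r' (residual) / 'n' (layer norm) / 'd' (dropout) sequence. The function name and defaults are illustrative, the layer norm omits the learnable scale and shift, and dropout uses inverted scaling.

import numpy as np

def process_sequence(data, prev=None, sequence="nr", dropout=0.0, eps=1e-06):
    # Residual steps need a previous value, mirroring the assert above.
    if prev is None:
        assert "r" not in sequence, "Residual connection needs a previous value."
    for step in sequence:
        if step == "r":
            data = data + prev                          # residual connection
        elif step == "n":
            mean = data.mean(axis=-1, keepdims=True)
            var = data.var(axis=-1, keepdims=True)
            data = (data - mean) / np.sqrt(var + eps)   # layer norm (no scale/shift)
        elif step == "d":
            if dropout > 0.0:                           # inverted dropout
                keep = np.random.rand(*data.shape) >= dropout
                data = data * keep / (1.0 - dropout)
        else:
            raise ValueError("Unknown step in sequence: %s" % step)
    return data

out = process_sequence(np.ones((2, 4, 8)), prev=np.ones((2, 4, 8)), sequence="rnd", dropout=0.1)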
Example #2
    def forward(self,
                queries: np.ndarray,
                key_values: np.ndarray,
                heads: int,
                lengths: Optional[np.ndarray] = None,
                bias: Optional[np.ndarray] = None):

        # (n*h, lq, lk)
        logits = npx.interleaved_matmul_encdec_qk(queries,
                                                  key_values,
                                                  heads=heads)

        if bias is not None:
            logits = logits + bias

        if lengths is not None:
            # required shape for lengths: (n*h, lq); required dtype: int32
            probs = npx.softmax(logits,
                                axis=-1,
                                length=lengths,
                                use_length=True)
        else:
            probs = npx.softmax(logits, axis=-1)

        probs = npx.dropout(probs,
                            p=self.dropout) if self.dropout > 0.0 else probs

        # key_values: (lk, n, dv * 2)
        # probs: (n*h, lq, lk)
        # result: (n, lq, dv)
        return npx.interleaved_matmul_encdec_valatt(key_values,
                                                    probs,
                                                    heads=heads)
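
The fused interleaved_matmul_encdec_* kernels operate on head-interleaved query/key-value projections, but the math they implement is ordinary dot-product attention. Below is a hedged plain-NumPy sketch of that math with heads already split out; names are illustrative, and the interleaved layout and length masking are not reproduced.

import numpy as np

def dot_attention_reference(q, k, v, bias=None):
    # q: (n*h, lq, dk), k: (n*h, lk, dk), v: (n*h, lk, dv)
    logits = q @ k.transpose(0, 2, 1)                      # (n*h, lq, lk)
    if bias is not None:
        logits = logits + bias
    logits = logits - logits.max(axis=-1, keepdims=True)   # numerically stable softmax
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return probs @ v                                       # (n*h, lq, dv)

q = np.random.rand(6, 4, 16)   # n*h = 6, lq = 4, dk = 16
k = np.random.rand(6, 5, 16)   # lk = 5
v = np.random.rand(6, 5, 32)   # dv = 32
assert dot_attention_reference(q, k, v).shape == (6, 4, 32)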
Example #3
    def forward(self, x: np.ndarray) -> np.ndarray:
        h = self.ff1(x)                 # first linear projection
        h = self.act(h)                 # non-linearity
        if self.use_glu:
            h = h * self.linear(x)      # GLU: gate with a linear projection of the input
        if self.dropout > 0.0:
            h = npx.dropout(h, p=self.dropout)
        y = self.ff2(h)                 # project back to the model dimension
        return y
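
A plain-NumPy sketch of the same position-wise feed-forward block, assuming explicit weight matrices and a ReLU-style activation standing in for self.act; dropout is omitted and all names are illustrative.

import numpy as np

def ffn_glu_reference(x, w1, b1, w_gate, b_gate, w2, b2, use_glu=True):
    h = np.maximum(x @ w1 + b1, 0.0)        # ff1 followed by the activation
    if use_glu:
        h = h * (x @ w_gate + b_gate)       # GLU: gate with a linear projection of x
    return h @ w2 + b2                      # ff2 back to the model dimension

d_model, d_ff = 8, 32
x = np.random.rand(2, 5, d_model)
params = [np.random.rand(d_model, d_ff), np.zeros(d_ff),
          np.random.rand(d_model, d_ff), np.zeros(d_ff),
          np.random.rand(d_ff, d_model), np.zeros(d_model)]
assert ffn_glu_reference(x, *params).shape == (2, 5, d_model)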
Example #4
    def forward(self, data, valid_length):
        # positional embedding
        data = self.pos_embedding(data, None)

        if self.config.dropout_prepost > 0.0:
            data = npx.dropout(data=data, p=self.config.dropout_prepost)

        # (batch_size * heads, seq_len)
        att_valid_length = layers.prepare_source_valid_lengths(valid_length, data,
                                                               num_heads=self.config.attention_heads)

        data = np.transpose(data, axes=(1, 0, 2))
        for block in self.layers:
            data = block(data, att_valid_length)

        data = self.final_process(data, None)
        data = np.transpose(data, axes=(1, 0, 2))
        return data, valid_length
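
A small shape check for the batch-major/time-major transposes used above, plus an assumed per-head broadcast of the valid lengths; the real prepare_source_valid_lengths helper may compute its result differently.

import numpy as np

batch_size, seq_len, model_size, num_heads = 2, 5, 8, 4
data = np.zeros((batch_size, seq_len, model_size))

# The transformer blocks consume time-major input; the final transpose restores
# the batch-major (batch_size, seq_len, model_size) layout.
time_major = np.transpose(data, axes=(1, 0, 2))
assert time_major.shape == (seq_len, batch_size, model_size)
assert np.transpose(time_major, axes=(1, 0, 2)).shape == data.shape

# Assumed broadcast: repeat each sequence length once per head and tile across
# positions to match the (batch_size * heads, seq_len) comment above.
valid_length = np.array([5, 3], dtype=np.int32)
att_valid_length = np.tile(np.repeat(valid_length, num_heads)[:, None], (1, seq_len))
assert att_valid_length.shape == (batch_size * num_heads, seq_len)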
Example #5
    def forward(self, data, valid_length):  # pylint: disable=arguments-differ
        # Collect optional factor embeddings, grouped by their combine strategy
        average_factors_embeds = []  # type: List[np.ndarray]
        concat_factors_embeds = []  # type: List[np.ndarray]
        sum_factors_embeds = []  # type: List[np.ndarray]
        if self.config.num_factors > 1 and self.config.factor_configs is not None:
            data, *data_factors = (np.squeeze(x, axis=2) for x in np.split(data, self.config.num_factors, axis=2))
            for i, (factor_data, factor_config) in enumerate(zip(data_factors,
                                                                 self.config.factor_configs)):
                factor_weight = self.factor_weights[i]
                factor_embedding = npx.embedding(factor_data,
                                                 input_dim=factor_config.vocab_size,
                                                 weight=factor_weight.data(),
                                                 output_dim=factor_config.num_embed)
                if factor_config.combine == C.FACTORS_COMBINE_CONCAT:
                    concat_factors_embeds.append(factor_embedding)
                elif factor_config.combine == C.FACTORS_COMBINE_SUM:
                    sum_factors_embeds.append(factor_embedding)
                elif factor_config.combine == C.FACTORS_COMBINE_AVERAGE:
                    average_factors_embeds.append(factor_embedding)
                else:
                    raise ValueError("Unknown combine value for factors: %s" % factor_config.combine)
        else:
            data = np.squeeze(data, axis=2)

        embed = npx.embedding(data,
                              weight=self.weight.data(),
                              input_dim=self.config.vocab_size,
                              output_dim=self.config.num_embed,
                              dtype=self._dtype,
                              sparse_grad=False)

        if self.config.num_factors > 1 and self.config.factor_configs is not None:
            if average_factors_embeds:
                embed = npx.add_n(embed, *average_factors_embeds) / (len(average_factors_embeds) + 1)
            if sum_factors_embeds:
                embed = npx.add_n(embed, *sum_factors_embeds)
            if concat_factors_embeds:
                embed = np.concatenate((embed, *concat_factors_embeds), axis=2)

        if self.config.dropout > 0:
            embed = npx.dropout(data=embed, p=self.config.dropout)

        return embed, np.copy(valid_length)  # See https://github.com/apache/incubator-mxnet/issues/14228
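
A plain-NumPy sketch of the combine logic above, with the embedding lookups omitted and an illustrative function name. Note that the average divides by len(average_factors_embeds) + 1 because the primary embedding itself takes part in the average.

import numpy as np

def combine_factor_embeddings(primary, sum_embeds=(), avg_embeds=(), concat_embeds=()):
    # primary: (batch, length, num_embed); factor embeddings share the leading dims.
    embed = primary
    if avg_embeds:
        embed = (embed + sum(avg_embeds)) / (len(avg_embeds) + 1)
    if sum_embeds:
        embed = embed + sum(sum_embeds)
    if concat_embeds:
        embed = np.concatenate((embed, *concat_embeds), axis=2)
    return embed

primary = np.ones((2, 3, 4))
out = combine_factor_embeddings(primary,
                                sum_embeds=(np.ones((2, 3, 4)),),
                                concat_embeds=(np.ones((2, 3, 2)),))
assert out.shape == (2, 3, 6)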
Example #6
def multi_head_dot_attn(query, key, value,
                        mask=None,
                        edge_scores=None,
                        dropout: float = 0.0,
                        scaled: bool = True, normalized: bool = False,
                        eps: float = 1E-6, query_head_units: Optional[int] = None,
                        layout: str = 'NKT',
                        use_einsum: bool = False):
    """Multihead dot product attention between the query, key, value.

    scaled is False, normalized is False:
        D(h_q, h_k) = <h_q, h_k>
    scaled is True, normalized is False:
        D(h_q, h_k) = <h_q, h_k> / sqrt(dim_q)
    scaled is False, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||>
    scaled is True, normalized is True:
        D(h_q, h_k) = <h_q / ||h_q||, h_k / ||h_k||> / sqrt(dim_q)

    If edge_scores is provided, we will calculate the attention as
        scores = D(h_q, h_k) + EdgeScore_{q, k}

    Parameters
    ----------
    query
        Query. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, query_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, query_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads, key_dim)

    key
        Key. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, key_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, key_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, key_dim)

    value
        Value. The shape depends on the layout

        - layout is 'NKT'
            Shape (batch_size, num_heads, mem_length, value_dim)
        - layout is 'NTK'
            Shape (batch_size, mem_length, num_heads, value_dim)
        - layout is 'TNK'
            Shape (mem_length, batch_size, num_heads, value_dim)

    mask
        Mask between query and memory. Shape (batch_size, query_length, mem_length)
    edge_scores
        The edge attention score. Shape can be any shape that is broadcastable to
        (batch_size, num_heads, query_length, mem_length)
    dropout
        Dropout rate
    scaled
        Whether to divide the attention weights by the sqrt of the query dimension.
        This is first proposed in "[NIPS2017] Attention is all you need."

        .. code-block:: none

            score = <h_q, h_k> / sqrt(dim_q)

    normalized
        If turned on, the cosine distance is used, i.e.

        .. code-block:: none

            score = <h_q / ||h_q||, h_k / ||h_k||>

    eps
        The epsilon value used in L2 normalization
    query_head_units
        The units of each query head. Required when scaled is True; the current
        implementation raises NotImplementedError if it is not given in that case.
    layout
        This stands for the layout of the attention cell. The shape of the input/output will depend
        on the layout. Currently, we support 'NKT', 'NTK' and 'TNK' in which
        'N' means the batch_size, 'K' means the head, and 'T' means the length dimension.
    use_einsum
        Whether to use einsum for the computation

    Returns
    -------
    context_vec
        - layout is 'NKT' or 'NTK'
            Shape (batch_size, query_length, num_heads * value_units)
        - layout is 'TNK'
            Shape (query_length, batch_size, num_heads * value_units)

    additional_info
        scores:
            Shape (batch_size, num_head, query_length, mem_length)
        attn_weight:
            Shape (batch_size, num_head, query_length, mem_length)
    """
    # TODO(sxjscience) Profile layout
    if normalized:
        query = l2_normalize(query, axis=-1, eps=eps)
        key = l2_normalize(key, axis=-1, eps=eps)
    if scaled:
        if query_head_units is None:
            raise NotImplementedError('You will need to specify query_head_units!')
        else:
            scale = math.sqrt(query_head_units)
    else:
        scale = None
    if layout == 'NKT':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (B, N, L_query, C_Q) X (B, N, L_mem, C_Q) --> (B, N, L_query, L_mem)
        scores = npx.batch_dot(query, key, transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (B, N, L_mem, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bnjc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights, value).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'NTK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (B, L_query, N, C_Q) X (B, L_mem, N, C_Q) --> (B, N, L_query, L_mem)
        if use_einsum:
            scores = np.einsum('binc,bjnc->bnij', query, key)
        else:
            scores = npx.batch_dot(np.swapaxes(query, 1, 2), np.swapaxes(key, 1, 2),
                                   transpose_b=True)
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (B, L_mem, N, C_V) --> (B, L_query, N * C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,bjnc->binc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    elif layout == 'TNK':
        # 1. Expand the dimension of the mask:
        #   (B, L_query, L_mem) --> (B, 1, L_query, L_mem)
        if mask is not None:
            mask = np.expand_dims(mask, axis=1).astype(np.bool)
        # 2. Calculate the attention weights
        #   Score: (L_query, B, N, C_Q) X (L_mem, B, N, C_Q) --> (B, N, L_query, L_mem)
        #   This layout structure can be implemented very efficiently because B, N are consecutive
        #   to each other. To have a clear picture of what's happening, we may consider the
        #   (i, j)th element of the output
        #       out[i, j, :, :] = query[:, i, j, :] X key[:, i, j, :].T, which is just one GEMM call
        #   We can thus implement the whole kernel via a single call of batched GEMM with stride.
        if use_einsum:
            scores = np.einsum('ibnc,jbnc->bnij', query, key)
        else:
            scores = npx.batch_dot(query.transpose((1, 2, 0, 3)),
                                   key.transpose((1, 2, 3, 0)))
        if edge_scores is not None:
            scores = scores + edge_scores
        attn_weights = masked_softmax(scores, mask, axis=-1, temperature=scale)
        attn_weights = npx.dropout(attn_weights, p=dropout)
        # 3. Calculate the context vector
        # (B, N, L_query, L_mem) X (L_mem, B, N, C_V) --> (L_query, B, N * C_V)
        # Again, we can implement it via a single call to batched GEMM with stride.

        # Shape (B, N, L_query, C_V)
        if use_einsum:
            context_vec = np.einsum('bnij,jbnc->ibnc', attn_weights, value)
        else:
            context_vec = npx.batch_dot(attn_weights,
                                        value.transpose((1, 2, 0, 3))).transpose((2, 0, 1, 3))
        context_vec = npx.reshape(context_vec, (-2, -2, -1))
    else:
        raise NotImplementedError('layout="{}" is not supported! '
                                  'We only support layout = "NKT", "NTK", and "TNK".'
                                  .format(layout))
    return context_vec, [scores, attn_weights]
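
Below is a self-contained plain-NumPy reference for the 'NKT' branch, with dropout, edge_scores and the normalized (cosine) variant omitted. It assumes masked_softmax treats True entries of the mask as valid positions, which is an assumption, not something stated in the snippet.

import numpy as np

def nkt_attention_reference(query, key, value, mask=None, scale=None):
    # query/key: (B, N, L_q/L_mem, C_Q); value: (B, N, L_mem, C_V); mask: (B, L_q, L_mem)
    scores = np.einsum('bnic,bnjc->bnij', query, key)       # (B, N, L_q, L_mem)
    logits = scores / scale if scale is not None else scores
    if mask is not None:
        logits = np.where(mask[:, None, :, :], logits, -1e18)
    logits = logits - logits.max(axis=-1, keepdims=True)    # stable softmax over L_mem
    attn_weights = np.exp(logits)
    attn_weights = attn_weights / attn_weights.sum(axis=-1, keepdims=True)
    context = np.einsum('bnij,bnjc->binc', attn_weights, value)  # (B, L_q, N, C_V)
    b, lq, n, cv = context.shape
    return context.reshape(b, lq, n * cv), [scores, attn_weights]

B, N, Lq, Lm, C = 2, 4, 3, 5, 8
ctx, _ = nkt_attention_reference(np.random.rand(B, N, Lq, C),
                                 np.random.rand(B, N, Lm, C),
                                 np.random.rand(B, N, Lm, C),
                                 scale=np.sqrt(C))
assert ctx.shape == (B, Lq, N * C)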
Example #7
    def forward(
            self, step_input: np.ndarray,
            states: List[np.ndarray]) -> Tuple[np.ndarray, List[np.ndarray]]:
        mask = None
        if self.inference_only:
            steps, source_valid_length, *other = states
            source_encoded = None  # use constant pre-computed key value projections from the states
            enc_att_kv = other[:self.config.num_layers]
            autoregr_states = other[self.config.num_layers:]
        else:
            if any(layer.needs_mask for layer in self.layers):
                # mask: (1, length, length)
                mask = self.autoregressive_bias(step_input)
            steps, source_encoded, source_valid_length, *autoregr_states = states
            enc_att_kv = [None for _ in range(self.config.num_layers)]

        if any(layer.num_state_tensors > 1 for layer in self.layers):
            # separates autoregressive states by layer
            states_iter = iter(autoregr_states)
            autoregr_states = [
                list(islice(states_iter, 0, layer.num_state_tensors))
                for layer in self.layers
            ]

        # (batch_size * heads, query_length)
        source_valid_length = layers.prepare_source_valid_lengths(
            source_valid_length,
            step_input,
            num_heads=self.config.attention_heads)

        # target: (batch_size, length, model_size)
        target = self.pos_embedding(step_input, steps)
        # (length, batch_size, model_size)
        target = np.transpose(target, axes=(1, 0, 2))

        if self.config.dropout_prepost > 0.0:
            target = npx.dropout(data=target, p=self.config.dropout_prepost)

        new_autoregr_states = []
        for layer, layer_autoregr_state, layer_enc_att_kv in zip(
                self.layers, autoregr_states, enc_att_kv):
            target, new_layer_autoregr_state = layer(target, mask,
                                                     source_encoded,
                                                     source_valid_length,
                                                     layer_autoregr_state,
                                                     layer_enc_att_kv)

            new_autoregr_states += [*new_layer_autoregr_state]

        target = self.final_process(target, None)
        target = np.transpose(target, axes=(1, 0, 2))

        # Inference: increment steps by 1 (discarded in training)
        steps = steps + 1

        if self.inference_only:
            # pass in cached encoder states
            encoder_attention_keys_values = states[2:2 + self.config.num_layers]
            new_states = ([steps, states[1]] + encoder_attention_keys_values
                          + new_autoregr_states)
        else:
            encoder_outputs = states[1]
            encoder_valid_length = states[2]
            new_states = ([steps, encoder_outputs, encoder_valid_length]
                          + new_autoregr_states)

        return target, new_states
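
A minimal illustration of the islice-based split above, which carves the flat list of autoregressive state tensors into one sub-list per layer; the layer sizes and state names here are purely illustrative.

from itertools import islice

num_state_tensors = [2, 1, 2]                    # e.g. three decoder layers
flat_states = ['s0', 's1', 's2', 's3', 's4']     # stand-ins for state tensors

states_iter = iter(flat_states)
per_layer = [list(islice(states_iter, 0, n)) for n in num_state_tensors]
assert per_layer == [['s0', 's1'], ['s2'], ['s3', 's4']]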