Example 1
    def forward(self,
                query,
                key,
                value,
                attn_mask=None,
                use_cache=False,
                cache=None):
        r"""
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.
        """
        key = query if key is None else key
        value = query if value is None else value
        # compute q, k, v
        if use_cache is False:
            if self.fuse:
                q, k, v = self._fuse_prepare_qkv(query)
            else:
                q, k, v = self._prepare_qkv(query, key, value, use_cache,
                                            cache)
        else:
            q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
                                               cache)
        # scaled dot-product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.head_dim**-0.5)

        # the fused kernel below replaces this non-fused masked-softmax path:
        #     if attn_mask is not None:
        #         product = product + attn_mask
        #     weights = F.softmax(product)

        weights = incubate.softmax_mask_fuse_upper_triangle(product)

        if self.dropout:
            with get_rng_state_tracker().rng_state('local_seed'):
                weights = F.dropout(weights,
                                    self.dropout,
                                    training=self.training,
                                    mode="upscale_in_train")

        out = tensor.matmul(weights, v)

        # combine heads
        out = tensor.transpose(out, perm=[0, 2, 1, 3])
        out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        if use_cache:
            outs.append(cache)
        return out if len(outs) == 1 else tuple(outs)
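
The fused call above stands in for the commented-out mask-and-softmax path. As a rough sanity check, the sketch below compares incubate.softmax_mask_fuse_upper_triangle against an explicit upper-triangular (causal) mask followed by a plain softmax; it assumes a GPU build of Paddle (the fused kernel is GPU-only), and the shapes, dtype, and tolerance are chosen only for illustration.

import paddle
import paddle.nn.functional as F
import paddle.incubate as incubate

# [batch, num_heads, seq_len, seq_len] attention scores; fp16 is typical for the fused kernel
product = paddle.randn([2, 4, 32, 32]).astype("float16")

# fused causal mask + softmax (requires a GPU build of Paddle)
fused = incubate.softmax_mask_fuse_upper_triangle(product)

# reference: add a large negative value above the diagonal, then softmax
seq_len = product.shape[-1]
causal_mask = paddle.triu(
    paddle.full([seq_len, seq_len], -1e4, dtype=product.dtype), diagonal=1)
reference = F.softmax(product + causal_mask)

print(paddle.allclose(fused.astype("float32"),
                      reference.astype("float32"), atol=1e-3))
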
Example 2
    def forward(self,
                tgt,
                memory=None,
                tgt_mask=None,
                use_cache=False,
                cache=None):
        residual = tgt

        if self.normalize_before:
            tgt = self.norm1(tgt)

        if use_cache is False:
            tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
        else:
            tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
                                                    use_cache, cache)

        with get_rng_state_tracker().rng_state('global_seed'):
            tgt = residual + self.dropout1(tgt)

        if not self.normalize_before:
            tgt = self.norm1(tgt)

        residual = tgt
        if self.normalize_before:
            tgt = self.norm2(tgt)

        if self.expert_mode:
            tgt = self.moe_mlp(tgt)
        else:
            with get_rng_state_tracker().rng_state('global_seed'):
                tgt = self.dropout2(
                    self.linear2(F.gelu(self.linear1(tgt), approximate=True)))

        tgt = residual + tgt

        if not self.normalize_before:
            tgt = self.norm2(tgt)

        return tgt if use_cache is False else (tgt, incremental_cache)
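
For reference, the normalize_before flag above switches between pre-norm and post-norm residual wiring. A minimal standalone sketch of the two orders, using hypothetical toy shapes rather than the layer above:

import paddle
import paddle.nn as nn

norm = nn.LayerNorm(8)
sublayer = nn.Linear(8, 8)
x = paddle.randn([2, 4, 8])

# pre-norm (normalize_before=True): normalize the sublayer input, then add the residual
pre_norm_out = x + sublayer(norm(x))

# post-norm (normalize_before=False): add the residual first, then normalize the sum
post_norm_out = norm(x + sublayer(x))
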
Example 3
def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank):
    assert args.device != "cpu"

    random.seed(basic_seed + data_world_rank)
    np.random.seed(basic_seed + data_world_rank)
    paddle.seed(basic_seed + data_world_rank)

    # local_seed / global_seed are used to control dropout under ModelParallel
    local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
    global_seed = basic_seed + data_world_rank
    tracker = get_rng_state_tracker()
    tracker.add('global_seed', global_seed)
    tracker.add('local_seed', local_seed)
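
A hypothetical call site for the function above, assuming fleet.init has already built the hybrid communicate group; the rank getters shown are an assumption about how the ranks would be obtained from paddle.distributed.fleet's HybridCommunicateGroup, and the seed value is illustrative.

from paddle.distributed import fleet

hcg = fleet.get_hybrid_communicate_group()
set_hyrbid_parallel_seed(
    basic_seed=1234,                                # illustrative value
    data_world_rank=hcg.get_data_parallel_rank(),   # data-parallel rank
    mp_rank=hcg.get_model_parallel_rank(),          # tensor (model) parallel rank
    pp_rank=hcg.get_stage_id())                     # pipeline stage
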
Example 4
    def forward(self, input_ids, position_ids=None):
        if position_ids is None:
            ones = paddle.ones_like(input_ids, dtype="int64")
            seq_length = paddle.cumsum(ones, axis=-1)
            position_ids = seq_length - ones

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = input_embeddings + position_embeddings

        with get_rng_state_tracker().rng_state('global_seed'):
            embeddings = self.dropout(embeddings)

        return embeddings
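
The default position_ids above are simply 0 .. seq_len-1 per row, built from a cumulative sum of ones. A minimal sketch of that construction in isolation:

import paddle

input_ids = paddle.to_tensor([[5, 7, 9, 2]], dtype="int64")
ones = paddle.ones_like(input_ids, dtype="int64")
position_ids = paddle.cumsum(ones, axis=-1) - ones
print(position_ids)  # [[0, 1, 2, 3]]
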
Example 5
def set_hyrbid_parallel_seed(basic_seed, dp_rank, mp_rank, pp_rank):
    assert args.device != "cpu"

    random.seed(basic_seed + dp_rank)
    np.random.seed(basic_seed + dp_rank)
    paddle.seed(basic_seed + dp_rank)
    from paddle.distributed.fleet import meta_parallel
    meta_parallel.model_parallel_random_seed(basic_seed + dp_rank +
                                             1000 * mp_rank)

    # local_seed / global_seed are used to control dropout under ModelParallel
    local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
    global_seed = basic_seed + dp_rank
    tracker = get_rng_state_tracker()
    tracker.add('global_seed', global_seed)
    tracker.add('local_seed', local_seed)
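
The seed arithmetic above appears to follow the Megatron-style RNG-tracker layout: global_seed is identical across model-parallel and pipeline ranks within one data-parallel group (so dropout on replicated activations matches), while local_seed differs per (mp_rank, pp_rank) (so dropout on partitioned activations decorrelates). A quick plain-Python check of that layout, with an illustrative seed:

basic_seed = 1234  # illustrative value
for dp_rank in range(2):
    for mp_rank in range(2):
        for pp_rank in range(2):
            local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
            global_seed = basic_seed + dp_rank
            print(f"dp={dp_rank} mp={mp_rank} pp={pp_rank} "
                  f"local_seed={local_seed} global_seed={global_seed}")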