def forward(self, query, key, value, attn_mask=None, use_cache=False, cache=None):
    r"""
    Applies multi-head attention to map queries and a set of key-value pairs
    to outputs.
    """
    key = query if key is None else key
    value = query if value is None else value
    # compute q, k, v
    if use_cache is False:
        if self.fuse:
            q, k, v = self._fuse_prepare_qkv(query)
        else:
            q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
    else:
        q, k, v, cache = self._prepare_qkv(query, key, value, use_cache, cache)

    # scale dot product attention
    product = layers.matmul(
        x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)

    # if attn_mask is not None:
    #     product = product + attn_mask
    # weights = F.softmax(product)
    weights = incubate.softmax_mask_fuse_upper_triangle(product)

    if self.dropout:
        with get_rng_state_tracker().rng_state('local_seed'):
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train")

    out = tensor.matmul(weights, v)

    # combine heads
    out = tensor.transpose(out, perm=[0, 2, 1, 3])
    out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

    # project to output
    out = self.out_proj(out)

    outs = [out]
    if self.need_weights:
        outs.append(weights)
    if use_cache:
        outs.append(cache)
    return out if len(outs) == 1 else tuple(outs)
def forward(self, tgt, memory=None, tgt_mask=None, use_cache=False, cache=None):
    residual = tgt

    if self.normalize_before:
        tgt = self.norm1(tgt)

    if use_cache is False:
        tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
    else:
        tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
                                                use_cache, cache)
    with get_rng_state_tracker().rng_state('global_seed'):
        tgt = residual + self.dropout1(tgt)

    if not self.normalize_before:
        tgt = self.norm1(tgt)

    residual = tgt
    if self.normalize_before:
        tgt = self.norm2(tgt)

    if self.expert_mode:
        tgt = self.moe_mlp(tgt)
    else:
        with get_rng_state_tracker().rng_state('global_seed'):
            tgt = self.dropout2(
                self.linear2(F.gelu(self.linear1(tgt), approximate=True)))

    tgt = residual + tgt

    if not self.normalize_before:
        tgt = self.norm2(tgt)

    return tgt if use_cache is False else (tgt, incremental_cache)
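For reference, here is a minimal single-process sketch, not part of the model code above, that shows how the 'local_seed' and 'global_seed' streams registered with get_rng_state_tracker() gate the randomness of dropout calls such as those in the attention and decoder layers above. It assumes a GPU device, since the tracker snapshots CUDA RNG states, and the seed values 1024 and 2048 are arbitrary placeholders.

import paddle
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

# Register the two seed streams the layers above rely on.
# (In real training this is done once per process by the seed helper below.)
tracker = get_rng_state_tracker()
tracker.add('global_seed', 1024)  # arbitrary value for illustration
tracker.add('local_seed', 2048)   # arbitrary value for illustration

x = paddle.ones([2, 4])

# Dropout inside tensor-parallel regions draws from the 'local_seed' stream,
# so different model-parallel ranks would produce different masks.
with tracker.rng_state('local_seed'):
    y_local = F.dropout(x, p=0.5, training=True)

# Dropout on replicated activations draws from the 'global_seed' stream,
# so ranks that share the same data-parallel rank would see identical masks.
with tracker.rng_state('global_seed'):
    y_global = F.dropout(x, p=0.5, training=True)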
def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank):
    assert args.device != "cpu"

    random.seed(basic_seed + data_world_rank)
    np.random.seed(basic_seed + data_world_rank)
    paddle.seed(basic_seed + data_world_rank)

    # local_seed/global_seed are used to control dropout in ModelParallel
    local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
    global_seed = basic_seed + data_world_rank
    tracker = get_rng_state_tracker()
    tracker.add('global_seed', global_seed)
    tracker.add('local_seed', local_seed)
def forward(self, input_ids, position_ids=None):
    if position_ids is None:
        ones = paddle.ones_like(input_ids, dtype="int64")
        seq_length = paddle.cumsum(ones, axis=-1)
        position_ids = seq_length - ones

    input_embedings = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)
    embeddings = input_embedings + position_embeddings

    with get_rng_state_tracker().rng_state('global_seed'):
        embeddings = self.dropout(embeddings)

    return embeddings
def set_hyrbid_parallel_seed(basic_seed, dp_rank, mp_rank, pp_rank):
    assert args.device != "cpu"

    random.seed(basic_seed + dp_rank)
    np.random.seed(basic_seed + dp_rank)
    paddle.seed(basic_seed + dp_rank)

    from paddle.distributed.fleet import meta_parallel
    meta_parallel.model_parallel_random_seed(
        basic_seed + dp_rank + 1000 * mp_rank)

    # local_seed/global_seed are used to control dropout in ModelParallel
    local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
    global_seed = basic_seed + dp_rank
    tracker = get_rng_state_tracker()
    tracker.add('global_seed', global_seed)
    tracker.add('local_seed', local_seed)
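In a launch script, the three ranks passed to set_hyrbid_parallel_seed typically come from fleet's hybrid communicate group. The following is a hedged sketch, not taken from the snippet above: the 2x2x2 hybrid_configs degrees, the seed value 1024, and the argparse stand-in for the module-level args are illustrative assumptions, and the script must be started with python -m paddle.distributed.launch on enough GPUs for the chosen degrees.

import argparse
import random

import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

# Stand-in for the module-level `args` that set_hyrbid_parallel_seed reads.
args = argparse.Namespace(device="gpu")

# Illustrative 2x2x2 hybrid layout; the degrees must match the launched world size.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 2, "pp_degree": 2}
fleet.init(is_collective=True, strategy=strategy)

# Query this process's position in each parallel dimension.
hcg = fleet.get_hybrid_communicate_group()
dp_rank = hcg.get_data_parallel_rank()
mp_rank = hcg.get_model_parallel_rank()
pp_rank = hcg.get_stage_id()

# Seed Python/NumPy/Paddle RNGs and register the dropout seed streams.
set_hyrbid_parallel_seed(1024, dp_rank, mp_rank, pp_rank)

With this setup, every (mp_rank, pp_rank) combination receives a distinct local_seed for the dropout inside the tensor-parallel attention, while ranks that differ only in mp or pp share the same global_seed used by the remaining dropout calls.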