def _fuse_prepare_qkv(self, query):
    # Project once, then split the fused projection into q, k, v.
    mix_layer = self.qkv_proj(query)
    # 0 in the target shape keeps that dim: [bs, seq_len, num_heads, 3 * head_dim].
    mix_layer = paddle.reshape_(mix_layer, [0, 0, self.num_heads, 3 * self.head_dim])
    # [bs, num_heads, seq_len, 3 * head_dim]
    mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
    q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)
    return q, k, v
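# A minimal, self-contained sketch of the reshape/split trick above, using
# hypothetical sizes (hidden=64, num_heads=4, head_dim=16) and a stand-in
# qkv_proj; the 0 entries passed to paddle.reshape_ keep the batch and
# sequence dims unchanged.
import paddle

num_heads, head_dim = 4, 16
qkv_proj = paddle.nn.Linear(64, 3 * num_heads * head_dim)

query = paddle.randn([2, 8, 64])           # [bs, seq_len, hidden]
mix = qkv_proj(query)                      # [2, 8, 192]
mix = paddle.reshape_(mix, [0, 0, num_heads, 3 * head_dim])
mix = paddle.transpose(mix, [0, 2, 1, 3])  # [2, 4, 8, 48]
q, k, v = paddle.split(mix, num_or_sections=3, axis=-1)
print(q.shape)                             # [2, 4, 8, 16]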
def forward(self, x):
    """Define how the head is going to run."""
    x = self.fcn(x)
    # N, C, 1, 1 --> N, C
    x = paddle.reshape_(x, (x.shape[0], -1))
    return x
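# A brief usage sketch (sizes assumed) of the flattening step: after global
# pooling the feature map is [N, C, 1, 1], and paddle.reshape_ collapses it
# in place to [N, C].
import paddle

x = paddle.randn([4, 512, 1, 1])           # e.g. output of a global pooling layer
x = paddle.reshape_(x, (x.shape[0], -1))
print(x.shape)                             # [4, 512]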
def forward_net(self, imgs):
    # NOTE: num_segs is an attribute of the dataset phase and is not passed to
    # the build_head phase, so obtain it from imgs (a paddle.Tensor) here and
    # then call self.head with it.
    num_segs = imgs.shape[1]  # imgs.shape = [N, T, C, H, W] in the most common case
    # Fold the temporal dimension into the batch: [N * T, C, H, W].
    imgs = paddle.reshape_(imgs, [-1] + list(imgs.shape[2:]))

    if self.backbone is not None:
        feature = self.backbone(imgs)
    else:
        feature = imgs

    if self.head is not None:
        cls_score = self.head(feature, num_segs)
    else:
        cls_score = None

    return cls_score
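# A minimal sketch of the segment folding, with assumed sizes: the video batch
# [N, T, C, H, W] is reshaped in place to [N * T, C, H, W] so a 2D backbone
# can treat every frame as an independent sample, while num_segs (= T) is kept
# so the head can regroup per-video scores afterwards.
import paddle

imgs = paddle.randn([2, 8, 3, 32, 32])     # [N, T, C, H, W]
num_segs = imgs.shape[1]
imgs = paddle.reshape_(imgs, [-1] + list(imgs.shape[2:]))
print(num_segs, imgs.shape)                # 8 [16, 3, 32, 32]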
def forward(self, inp):
    # inp shape: b * s * m
    assert len(inp.shape) == 3
    origin_shape = inp.shape
    inp = inp.reshape_([-1, origin_shape[2]])

    mp_rank = 0
    mp_size = 1
    if self.mp_group is not None:
        mp_rank = self.mp_group.rank
        mp_size = self.mp_group.nranks
    if mp_size > 1:
        # Shard the flattened tokens across the model-parallel group.
        if in_dygraph_mode():
            inp = EagerSlice.apply(inp, mp_rank, mp_size, self.mp_group)
        else:
            inp = Slice.apply(inp, mp_rank, mp_size, self.mp_group)
    value, gate = self.gate(inp)

    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, self.num_expert, self.world_size, self.group)

    topk = 1
    if len(gate.shape) == 2:
        topk = gate.shape[1]

    if pos.shape != [0]:
        temp_pos = pos // topk
    else:
        temp_pos = pos
    assert topk == self.top_k

    if in_dygraph_mode():
        x = EagerMoEScatter.apply(inp, temp_pos, local_expert_count,
                                  global_expert_count, fwd_batch_size,
                                  self.world_size, self.group)
    else:
        x = MoEScatter.apply(inp, temp_pos, local_expert_count,
                             global_expert_count, fwd_batch_size,
                             self.world_size, self.group)

    d_model = self.d_model

    def experts_fwd(x, fwd_expert_count, experts):
        if x.shape[0] == 0:
            return x
        y = []
        last_index = 0
        assert isinstance(fwd_expert_count, np.ndarray)
        assert len(experts) == len(fwd_expert_count)
        for idx, expert_count in enumerate(fwd_expert_count):
            if expert_count <= 0:
                continue
            # Each expert runs on its contiguous slice of routed tokens.
            y.append(experts[idx](x[last_index:expert_count + last_index]))
            last_index = expert_count + last_index
        return paddle.concat(y, axis=0)

    if self.recompute_interval <= 0 or x.shape[0] == 0:
        x = experts_fwd(x, fwd_expert_count.numpy(), self.experts)
    else:
        x = _hp_recompute(experts_fwd, x, fwd_expert_count.numpy(),
                          self.experts)

    out_batch_size = inp.shape[0]
    if len(gate.shape) == 2:
        out_batch_size *= gate.shape[1]

    if in_dygraph_mode():
        x = EagerMoEGather.apply(x, pos, local_expert_count,
                                 global_expert_count, out_batch_size,
                                 self.world_size, self.group)
    else:
        x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
                            out_batch_size, self.world_size, self.group)

    # Combine the top_k expert outputs, weighted by the gate values.
    x = x.reshape([-1, self.top_k, d_model])
    value = value.reshape([x.shape[0], 1, self.top_k])
    x = paddle.bmm(value, x).reshape([-1, d_model])

    if mp_size > 1:
        if in_dygraph_mode():
            x = EagerAllGather.apply(x, mp_rank, mp_size, self.mp_group)
        else:
            x = AllGather.apply(x, mp_rank, mp_size, self.mp_group)

    x = paddle.reshape_(x, origin_shape)
    return x
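# A self-contained sketch of the experts_fwd dispatch loop alone (the
# distributed scatter/gather pieces are omitted): tokens routed to expert i
# occupy one contiguous slice of x, and fwd_expert_count[i] is the number of
# rows in that slice. The expert modules and counts here are illustrative
# assumptions.
import numpy as np
import paddle

d_model = 8
experts = [paddle.nn.Linear(d_model, d_model) for _ in range(3)]
fwd_expert_count = np.array([2, 0, 3])     # expert 1 receives no tokens
x = paddle.randn([int(fwd_expert_count.sum()), d_model])

y, last_index = [], 0
for idx, expert_count in enumerate(fwd_expert_count.tolist()):
    if expert_count <= 0:
        continue
    y.append(experts[idx](x[last_index:last_index + expert_count]))
    last_index += expert_count
out = paddle.concat(y, axis=0)
print(out.shape)                           # [5, 8]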
def inplace_api_processing(self, var):
    return paddle.reshape_(var, [-1])
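# A short sketch of the inplace semantics this helper exercises:
# paddle.reshape_ reshapes var itself (note the trailing underscore), unlike
# paddle.reshape, which leaves var untouched and returns a new tensor.
import paddle

var = paddle.randn([2, 3])
out = paddle.reshape_(var, [-1])
print(var.shape, out.shape)                # [6] [6] -- var itself was reshaped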