def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
    if self.wall_clock_breakdown:
        self.timers('moe').start()

    # Implement Algorithm 2 from GShard paper.
    d_model = input[0].shape[-1]

    # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
    # Reshape into G groups so that each group can distribute tokens equally
    # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
    reshaped_input = input[0].reshape(-1, d_model)

    if self.use_tutel:
        self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(
            reshaped_input, input[1], True)
        S, M = reshaped_input.size(0), reshaped_input.size(1)

        if not hasattr(self, '_tutel_dispatcher'):
            self._tutel_dispatcher = tutel_moe.fast_dispatcher(
                E, C, M, dispatch_dtype=reshaped_input.dtype)
        self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
        dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
    else:
        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(
            reshaped_input, input[1])
        dispatched_input = einsum("sec,sm->ecm",
                                  dispatch_mask.type_as(input[0]),
                                  reshaped_input)

    if self.wall_clock_breakdown:
        self.timers('falltoall').start()

    dispatched_input = _AllToAll.apply(self.group, dispatched_input)

    if self.wall_clock_breakdown:
        self.timers('falltoall').stop()
        self.time_falltoall = self.timers('falltoall').elapsed(reset=False) * 1000

    # Re-shape after all-to-all: ecm -> gecm
    dispatched_input = dispatched_input.reshape(self.world_size,
                                                self.num_local_experts,
                                                -1,
                                                d_model)

    expert_output = self.experts(dispatched_input)

    if self.wall_clock_breakdown:
        self.timers('salltoall').start()

    expert_output = _AllToAll.apply(self.group, expert_output)

    if self.wall_clock_breakdown:
        self.timers('salltoall').stop()
        self.time_salltoall = self.timers('salltoall').elapsed(reset=False) * 1000

    # Re-shape back: gecm -> ecm
    expert_output = expert_output.reshape(self.world_size * self.num_local_experts,
                                          -1,
                                          d_model)

    if self.use_tutel:
        combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
    else:
        combined_output = einsum("sec,ecm->sm",
                                 combine_weights.type_as(input[0]),
                                 expert_output)

    a = combined_output.reshape(input[0].shape)

    if self.wall_clock_breakdown:
        self.timers('moe').stop()
        self.time_moe = self.timers('moe').elapsed(reset=False) * 1000

    return a
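
# ---------------------------------------------------------------------------
# Illustration only (not part of the layer above): a minimal sketch of the
# GShard-style dispatch/combine einsums on toy tensors. The sizes S, E, C, M
# and the one-hot mask below are made up for demonstration; in the real
# forward pass they come from the gate.
# ---------------------------------------------------------------------------
import torch

S, E, C, M = 8, 4, 2, 16              # tokens, experts, capacity per expert, model dim
tokens = torch.randn(S, M)

# Fake a one-hot dispatch mask (token s -> expert s % E, slot (s // E) % C) and
# reuse it as combine weights; the real gate fills these with routing probabilities.
dispatch_mask = torch.zeros(S, E, C)
for s in range(S):
    dispatch_mask[s, s % E, (s // E) % C] = 1.0
combine_weights = dispatch_mask.clone()

dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)      # (E, C, M)
combined = torch.einsum("sec,ecm->sm", combine_weights, dispatched)  # (S, M)
assert combined.shape == tokens.shape
assert torch.allclose(combined, tokens)  # one-hot weights make this a round trip
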
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
    if self.wall_clock_breakdown:
        self.timers('moe').start()

    # Implement Algorithm 2 from GShard paper.
    d_model = input[0].shape[-1]

    # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
    # Reshape into G groups so that each group can distribute tokens equally
    # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
    reshaped_input = input[0].reshape(-1, d_model)

    if self.use_tutel:
        self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(
            reshaped_input, input[1], True)
        S, M = reshaped_input.size(0), reshaped_input.size(1)

        if not hasattr(self, '_tutel_dispatcher'):
            self._tutel_dispatcher = tutel_moe.fast_dispatcher(
                E, C, M, dispatch_dtype=reshaped_input.dtype)
        self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
        dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
    else:
        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(
            reshaped_input, input[1])
        dispatched_input = einsum("sec,sm->ecm",
                                  dispatch_mask.type_as(input[0]),
                                  reshaped_input)

    if self.wall_clock_breakdown:
        self.timers('falltoall').start()

    if groups._get_expert_model_parallel_world_size() == 1:
        # If the non-expert is tensor-parallel, it will create
        # duplicate tokens on the tensor-parallel ranks.
        # Since our experts are not tensor-parallel, these duplicates
        # need to be dropped to ensure correctness.
        # This also doubles up as a communication optimization as we are
        # reducing the all-to-all communication volume.
        dispatched_input = drop_tokens(dispatched_input, dim=1)

    dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

    if self.wall_clock_breakdown:
        self.timers('falltoall').stop()
        self.time_falltoall = self.timers('falltoall').elapsed(reset=False)

    # Re-shape after all-to-all: ecm -> gecm
    dispatched_input = dispatched_input.reshape(self.ep_size,
                                                self.num_local_experts,
                                                -1,
                                                d_model)

    expert_output = self.experts(dispatched_input)

    if self.wall_clock_breakdown:
        self.timers('salltoall').start()

    expert_output = _AllToAll.apply(self.ep_group, expert_output)

    if self.wall_clock_breakdown:
        self.timers('salltoall').stop()
        self.time_salltoall = self.timers('salltoall').elapsed(reset=False)

    # Re-shape back: gecm -> ecm
    expert_output = expert_output.reshape(self.ep_size * self.num_local_experts,
                                          -1,
                                          d_model)

    if groups._get_expert_model_parallel_world_size() == 1:
        # The dropped duplicate tokens need to be gathered on each
        # tensor parallel rank again for the tensor-parallel
        # non-expert of the next layer.
        expert_output = gather_tokens(expert_output, dim=1)

    if self.use_tutel:
        combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
    else:
        combined_output = einsum("sec,ecm->sm",
                                 combine_weights.type_as(input[0]),
                                 expert_output)

    a = combined_output.reshape(input[0].shape)

    if self.wall_clock_breakdown:
        self.timers('moe').stop()
        self.time_moe = self.timers('moe').elapsed(reset=False)

    return a
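
# ---------------------------------------------------------------------------
# Conceptual, single-process illustration of the drop/gather idea used in the
# second variant above: when the surrounding non-expert layers are
# tensor-parallel but the experts are not, each tensor-parallel rank keeps only
# its slice of the duplicated tokens before the all-to-all and gathers the
# slices back afterwards. This is NOT DeepSpeed's drop_tokens/gather_tokens
# implementation; tp_rank and tp_world_size are placeholders standing in for
# the real tensor-parallel topology, and the gather is a simple concatenation
# instead of a collective.
# ---------------------------------------------------------------------------
import torch

def drop_duplicate_tokens(x: torch.Tensor, tp_rank: int, tp_world_size: int, dim: int = 1) -> torch.Tensor:
    # Each tensor-parallel rank holds the full (duplicated) tensor; keep only this
    # rank's contiguous slice along `dim`, shrinking the all-to-all volume.
    chunk = x.size(dim) // tp_world_size
    return x.narrow(dim, tp_rank * chunk, chunk).contiguous()

def gather_duplicate_tokens(shards: list, dim: int = 1) -> torch.Tensor:
    # Inverse operation: a real run would all-gather across the tensor-parallel
    # group; here we just concatenate the per-rank shards.
    return torch.cat(shards, dim=dim)

x = torch.randn(4, 8, 16)   # (experts, capacity, d_model), toy sizes
shards = [drop_duplicate_tokens(x, r, 2) for r in range(2)]
assert torch.allclose(gather_duplicate_tokens(shards), x)
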