def forward(
    ctx,
    global_output_buf,
    pos,
    local_expert_count,
    global_expert_count,
    local_batch_size,
    world_size,
):
    """Gather expert outputs back to the local worker's token order.

    Two steps: an all-to-all exchange (`global_gather`) that pulls this
    worker's rows out of the distributed expert-output buffer, then a
    `local_gather` that reorders those rows by ``pos`` into the original
    token layout.  Saves on ``ctx`` everything backward needs.
    """
    # A single worker has nothing to exchange — the global buffer is
    # already local.
    if world_size <= 1:
        gathered = global_output_buf
    else:
        (gathered,) = fmoe_cuda.global_gather(
            global_output_buf,
            local_expert_count,
            global_expert_count,
            local_batch_size,
            world_size,
        )
    (output,) = fmoe_cuda.local_gather(gathered, pos)
    # Stash the global buffer length (needed to size backward's output)
    # plus the index tensors.
    ctx.moe_args = (global_output_buf.shape[0], world_size)
    ctx.save_for_backward(pos, local_expert_count, global_expert_count)
    return output
def backward(ctx, global_grad_in):
    """Backward of the gather: scatter the gradient to the global buffer.

    The forward pass applied ``global_gather`` then ``local_gather``; its
    adjoint is the reverse communication in reverse order —
    ``local_scatter`` (undo the ``pos`` reordering) followed by
    ``global_scatter`` (send rows back to the workers whose expert
    outputs produced them).  The previous implementation re-applied the
    gather ops to the gradient, which is not the adjoint and produced an
    incorrect gradient for ``global_output_buf``.

    Returns one gradient per forward argument after ``ctx``; only
    ``global_output_buf`` is differentiable, so the rest are ``None``.
    """
    (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
    # First element was saved as global_output_buf.shape[0] in forward —
    # the size the scattered gradient buffer must have.
    (fwd_batch_size, world_size) = ctx.moe_args
    # Undo local_gather: route gradient rows back to expert order.
    (local_grad_out,) = fmoe_cuda.local_scatter(global_grad_in, pos)
    if world_size > 1:
        # Undo global_gather via the inverse all-to-all exchange.
        (global_grad_out,) = fmoe_cuda.global_scatter(
            local_grad_out,
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
            world_size,
        )
    else:
        global_grad_out = local_grad_out
    return global_grad_out, None, None, None, None, None