def backward(ctx, grad_out):
    """Propagate gradients back along the scatter path.

    Re-applies the forward routing (local reorder, then the cross-worker
    exchange) to the incoming gradient so each worker ends up with the
    gradient rows for the tokens it owns.

    Args:
        ctx: autograd context; carries the tensors saved by ``forward``
            plus ``moe_args = (fwd_batch_size, world_size)``.
        grad_out: gradient w.r.t. this Function's output.

    Returns:
        A 6-tuple — the gradient for the first forward argument, then
        ``None`` for the five non-differentiable arguments.
    """
    pos, local_expert_count, global_expert_count = ctx.saved_tensors
    fwd_batch_size, world_size = ctx.moe_args

    # Reorder gradient rows by `pos`; `.contiguous()` because the CUDA
    # kernel presumably expects a dense layout — TODO confirm in fmoe_cuda.
    (grad_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)

    if world_size == 1:
        # Single worker: no cross-worker exchange needed.
        routed_grad = grad_buf
    else:
        # NOTE(review): looks like an all-to-all exchange sized by the
        # per-expert counts — confirm against fmoe_cuda.global_scatter.
        (routed_grad,) = fmoe_cuda.global_scatter(
            grad_buf,
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
            world_size,
        )

    # One slot per forward argument; only the input tensor is differentiable.
    return routed_grad, None, None, None, None, None
def forward(
    ctx,
    inp,
    pos,
    local_expert_count,
    global_expert_count,
    fwd_batch_size,
    world_size,
):
    """Reorder rows of ``inp`` by ``pos`` and exchange them across workers.

    Args:
        ctx: autograd context used to stash state for ``backward``.
        inp: input rows to be routed to experts.
        pos: per-row destination ordering used by ``local_scatter``.
        local_expert_count: rows this worker sends per expert — presumably;
            verify against fmoe_cuda.
        global_expert_count: rows this worker receives per expert —
            presumably; verify against fmoe_cuda.
        fwd_batch_size: size of the exchanged buffer on this worker.
        world_size: number of participating workers.

    Returns:
        The routed buffer (exchanged across workers when ``world_size > 1``,
        otherwise just the locally reordered buffer).
    """
    # Put each row into expert-contiguous order on this worker.
    (reordered,) = fmoe_cuda.local_scatter(inp, pos)

    if world_size == 1:
        # Single worker: nothing to exchange.
        routed = reordered
    else:
        # NOTE(review): appears to be an all-to-all exchange sized by the
        # per-expert counts — confirm against fmoe_cuda.global_scatter.
        (routed,) = fmoe_cuda.global_scatter(
            reordered,
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
            world_size,
        )

    # Non-tensor state goes on ctx directly; tensors must go through
    # save_for_backward so autograd can track them.
    ctx.moe_args = inp.shape[0], world_size
    ctx.save_for_backward(pos, local_expert_count, global_expert_count)
    return routed