Example #1
 @staticmethod
 def backward(ctx, grad_out):
     # Recover the routing metadata saved by the forward pass.
     pos, local_expert_count, global_expert_count = ctx.saved_tensors
     fwd_batch_size, world_size = ctx.moe_args
     # Re-permute the incoming gradient into expert-contiguous order.
     (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
     if world_size > 1:
         # All-to-all: send each worker the gradient slices for the
         # tokens it contributed in the forward pass.
         (global_grad_out_buf,) = fmoe_cuda.global_scatter(
             grad_out_buf,
             local_expert_count,
             global_expert_count,
             fwd_batch_size,
             world_size,
         )
     else:
         # Single process: nothing to exchange.
         global_grad_out_buf = grad_out_buf
     # Gradients flow only to the first forward argument; the index and
     # count arguments are non-differentiable.
     return global_grad_out_buf, None, None, None, None, None
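
This backward pops `pos`, the expert counts, and `(fwd_batch_size, world_size)` off the context, so it is the gradient of a gather-style forward rather than of the scatter in Example #2. For orientation, here is a minimal sketch of what such a companion forward could look like, assuming the extension also exposes the inverse kernels `global_gather` and `local_gather` (FastMoE's `fmoe_cuda` follows this convention); the argument order mirrors the calls above and is otherwise an assumption.

 @staticmethod
 def forward(
     ctx,
     global_output_buf,
     pos,
     local_expert_count,
     global_expert_count,
     local_batch_size,
     world_size,
 ):
     # Sketch of a hypothetical companion forward, not shown on this page.
     if world_size > 1:
         # Reverse all-to-all: collect this worker's expert outputs
         # back from every other worker.
         (local_output_buf,) = fmoe_cuda.global_gather(
             global_output_buf,
             local_expert_count,
             global_expert_count,
             local_batch_size,
             world_size,
         )
     else:
         local_output_buf = global_output_buf
     # Undo the expert-order permutation applied by local_scatter.
     (output,) = fmoe_cuda.local_gather(local_output_buf, pos)
     # Save exactly what the backward above reads back out.
     ctx.moe_args = global_output_buf.shape[0], world_size
     ctx.save_for_backward(pos, local_expert_count, global_expert_count)
     return output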
Example #2
 @staticmethod
 def forward(
     ctx,
     inp,
     pos,
     local_expert_count,
     global_expert_count,
     fwd_batch_size,
     world_size,
 ):
     # Reorder input rows so tokens bound for the same expert are
     # contiguous in the local buffer.
     (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
     if world_size > 1:
         # All-to-all: ship each worker the tokens routed to the
         # experts it hosts.
         (global_input_buf,) = fmoe_cuda.global_scatter(
             local_input_buf,
             local_expert_count,
             global_expert_count,
             fwd_batch_size,
             world_size,
         )
     else:
         # Single process: the local buffer already is the global one.
         global_input_buf = local_input_buf
     # Stash the routing metadata for the backward pass.
     ctx.moe_args = inp.shape[0], world_size
     variables = (pos, local_expert_count, global_expert_count)
     ctx.save_for_backward(*variables)
     return global_input_buf
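
Both snippets lean on `fmoe_cuda.local_scatter`, which rearranges token rows into expert-contiguous order according to `pos`. To make the semantics concrete without the CUDA extension, here is a runnable pure-PyTorch reference for the local reordering and its inverse; the `_ref` names are illustrative, not part of the `fmoe_cuda` API.

 import torch

 def local_scatter_ref(inp, pos):
     # Row i of the output is row pos[i] of the input: tokens end up
     # grouped so that each expert's tokens are contiguous.
     return torch.index_select(inp, 0, pos)

 def local_gather_ref(buf, pos, out_rows):
     # Inverse permutation: row i of buf goes back to row pos[i].
     out = buf.new_zeros(out_rows, buf.shape[1])
     out.index_copy_(0, pos, buf)
     return out

 tokens, d_model = 6, 4
 inp = torch.randn(tokens, d_model)
 pos = torch.randperm(tokens)
 buf = local_scatter_ref(inp, pos)
 assert torch.equal(local_gather_ref(buf, pos, tokens), inp)

In the multi-process branch, `global_scatter` and `global_gather` play the analogous role across workers, exchanging per-expert slices in an all-to-all so that each expert receives every token routed to it.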