async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
    # Deserialize the inputs, dispatch them to this expert's forward pool,
    # and await the outputs computed by the runtime.
    inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
    future = self.experts[request.uid].forward_pool.submit_task(*inputs)
    # Note: outputs are serialized without compression here, unlike backward below.
    serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
    return runtime_pb2.ExpertResponse(tensors=serialized_response)
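# --- Example (not part of the handler above) ---
# A minimal client-side sketch of invoking the forward endpoint over gRPC.
# The stub class (runtime_pb2_grpc.ConnectionHandlerStub), the module paths,
# the endpoint address, and the expert uid are assumptions for illustration.
import grpc
from hivemind.proto import runtime_pb2, runtime_pb2_grpc  # assumed module paths
from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor  # assumed

async def call_forward_example(endpoint: str, uid: str, inputs):
    # Open an asyncio gRPC channel, send the serialized inputs to the expert
    # identified by uid, and deserialize the returned output tensors.
    async with grpc.aio.insecure_channel(endpoint) as channel:
        stub = runtime_pb2_grpc.ConnectionHandlerStub(channel)
        request = runtime_pb2.ExpertRequest(
            uid=uid, tensors=[serialize_torch_tensor(t) for t in inputs])
        response = await stub.forward(request)
        return [deserialize_torch_tensor(t) for t in response.tensors]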
async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
    # The request carries both the forward inputs and the gradients w.r.t. the
    # outputs; the backward pool returns gradients w.r.t. the inputs.
    inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
    future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
    # Serialize each input gradient with the compression scheme declared in
    # the expert's grad_inputs_schema.
    serialized_response = [
        serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
        for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
    ]
    return runtime_pb2.ExpertResponse(tensors=serialized_response)
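# --- Example (not part of the handler above) ---
# A minimal serving sketch: register a servicer exposing the forward/backward
# handlers above with an asyncio gRPC server. The generated helper name
# (add_ConnectionHandlerServicer_to_server) and the port are assumptions.
import grpc
from hivemind.proto import runtime_pb2_grpc  # assumed module path

async def serve_example(handler, port: int = 8080):
    # Bind the handler to an insecure port and serve until terminated.
    server = grpc.aio.server()
    runtime_pb2_grpc.add_ConnectionHandlerServicer_to_server(handler, server)
    server.add_insecure_port(f"[::]:{port}")
    await server.start()
    await server.wait_for_termination()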