def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
    # flatten saved inputs and incoming output gradients into a single payload for the remote expert
    payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
    grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(
        uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
    deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
    # DUMMY and the two None values are the gradient slots for forward's dummy, uid and stub arguments
    return (DUMMY, None, None, *deserialized_grad_inputs)

def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
    inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
    # serialize each tensor with the compression declared in the expert's schema
    backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
    serialized_tensors = [serialize_torch_tensor(tensor, proto.compression)
                          for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
    grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
    deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
    # the extra None (compared to the version above) is the gradient slot for the new `info` argument
    return (DUMMY, None, None, None, *deserialized_grad_inputs)

def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
            *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
    ctx.uid, ctx.stub = uid, stub
    ctx.save_for_backward(*inputs)
    outputs = stub.forward(runtime_pb2.ExpertRequest(
        uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
    deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
    return tuple(deserialized_outputs)

def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
            info: Dict[str, Any], *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    # detach to avoid pickling the computation graph
    inputs = tuple(tensor.cpu().detach() for tensor in inputs)
    ctx.uid, ctx.stub, ctx.info = uid, stub, info
    ctx.save_for_backward(*inputs)
    serialized_tensors = [serialize_torch_tensor(inp, proto.compression)
                          for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))]
    outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
    deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
    return tuple(deserialized_outputs)

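For context, a minimal sketch of how this forward/backward pair could be wired into a torch.autograd.Function and invoked with flattened inputs. The class name _RemoteModuleCall, the call_remote_expert helper, and the DUMMY definition below are assumptions for illustration, not a verbatim copy of the library.

import torch
from typing import Any, Dict

DUMMY = torch.empty(0, requires_grad=True)  # assumed: dummy tensor that makes autograd track the remote call


class _RemoteModuleCall(torch.autograd.Function):
    """Sketch: forward(ctx, dummy, uid, stub, info, *inputs) and backward(ctx, *grad_outputs)
    from the snippets above would be defined here as @staticmethod."""


def call_remote_expert(uid: str, stub, info: Dict[str, Any], *flat_inputs: torch.Tensor):
    # The positional arguments must match forward's signature one-to-one, which is why
    # backward pads its return value with (DUMMY, None, None, None, ...) before the input grads.
    return _RemoteModuleCall.apply(DUMMY, uid, stub, info, *flat_inputs)
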
async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
                             ) -> AsyncIterator[averaging_pb2.DownloadData]:
    """
    Get the up-to-date trainer state from a peer.
    The state consists of two parts: (serialized_metadata, tensors)

    - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
    - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
    """
    chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
    metadata, tensors = await self._get_current_state_from_host_process()

    for tensor in tensors:
        for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
            if metadata is not None:
                # attach the serialized metadata to the very first chunk only
                yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                metadata = None
            else:
                yield averaging_pb2.DownloadData(tensor_part=part)
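On the receiving side, the streamed chunks have to be regrouped per tensor before deserialization. Below is a minimal consumer sketch, assuming a gRPC stub whose rpc_download_state returns the stream produced above and a combine_from_streaming helper that is the inverse of split_for_streaming; the import paths and the use of tensor_part.dtype to detect the start of a new tensor are assumptions.

from hivemind.proto import averaging_pb2
from hivemind.utils import combine_from_streaming, deserialize_torch_tensor  # import paths assumed


async def download_state(stub):
    metadata, tensors, current_parts = None, [], []
    async for message in stub.rpc_download_state(averaging_pb2.DownloadRequest()):
        if message.metadata:
            metadata = message.metadata  # sent only with the very first chunk
        if message.tensor_part.dtype and current_parts:
            # assumed: only the first chunk of each tensor carries a dtype, marking a boundary
            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_parts)))
            current_parts = []
        current_parts.append(message.tensor_part)
    if current_parts:
        tensors.append(deserialize_torch_tensor(combine_from_streaming(current_parts)))
    return metadata, tensors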