def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
    # flatten (saved inputs, incoming gradients) into a single payload for the remote expert
    payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
    grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(
        uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
    deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
    # one gradient slot per forward argument: DUMMY for the trigger tensor, None for uid and stub
    return (DUMMY, None, None, *deserialized_grad_inputs)
def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
    inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs)))
    # the backward request carries (inputs, grad_outputs), so its schema is the concatenation
    # of the expert's forward schema and outputs schema
    backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
    serialized_tensors = [serialize_torch_tensor(tensor, proto.compression)
                          for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
    grad_inputs = ctx.stub.backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
    deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
    # one gradient slot per forward argument: DUMMY for the trigger tensor, None for uid, stub and info
    return (DUMMY, None, None, None, *deserialized_grad_inputs)
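# Illustrative sketch, not part of the original source: the forward/backward pair above follows
# the torch.autograd.Function protocol, in which backward() must return exactly one gradient per
# forward() argument. The toy Function below mirrors that contract locally, without gRPC;
# `_ToyRemoteCall` and the "expert.0" uid are made-up names for illustration only.
import torch
from typing import Optional, Tuple

class _ToyRemoteCall(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy: torch.Tensor, uid: str, x: torch.Tensor) -> torch.Tensor:
        ctx.uid = uid             # non-tensor state lives on ctx, like uid/stub/info above
        ctx.save_for_backward(x)  # saved tensors reappear in backward as ctx.saved_tensors
        return 2 * x.detach()     # stand-in for the remote expert's forward pass

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
        (x,) = ctx.saved_tensors
        # one slot per forward argument (dummy, uid, x): a placeholder gradient for dummy,
        # None for the non-differentiable uid, and the real gradient for x
        return (torch.zeros(0), None, 2 * grad_output)

x = torch.randn(3, requires_grad=True)
dummy = torch.empty(0, requires_grad=True)  # plays the role of DUMMY: it forces backward to run
y = _ToyRemoteCall.apply(dummy, "expert.0", x)
y.sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))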
def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
            *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
    ctx.uid, ctx.stub = uid, stub
    ctx.save_for_backward(*inputs)

    outputs = stub.forward(runtime_pb2.ExpertRequest(
        uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
    deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
    return tuple(deserialized_outputs)
def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
            info: Dict[str, Any], *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
    # detach to avoid pickling the computation graph
    inputs = tuple(tensor.cpu().detach() for tensor in inputs)
    ctx.uid, ctx.stub, ctx.info = uid, stub, info
    ctx.save_for_backward(*inputs)

    serialized_tensors = [serialize_torch_tensor(inp, proto.compression)
                          for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))]
    outputs = stub.forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
    deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
    return tuple(deserialized_outputs)
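# A minimal sketch of the schema-driven serialization used above, assuming each entry of
# nested_flatten(info["forward_schema"]) carries a `compression` field. `TensorProto` and
# `mock_serialize` are hypothetical stand-ins for illustration, not hivemind APIs.
from dataclasses import dataclass

import torch

@dataclass
class TensorProto:
    compression: str  # e.g. "NONE" or "FLOAT16"

def mock_serialize(tensor: torch.Tensor, compression: str) -> bytes:
    # downcast before wire-encoding when the schema asks for it
    data = tensor.half() if compression == "FLOAT16" else tensor
    return data.numpy().tobytes()

inputs = (torch.randn(4), torch.randn(2, 2))
schema = (TensorProto("FLOAT16"), TensorProto("NONE"))
payload = [mock_serialize(t, p.compression) for t, p in zip(inputs, schema)]
assert len(payload[0]) == 4 * 2 and len(payload[1]) == 4 * 4  # float16 vs float32 bytes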
async def _load_state_from_peers(self, future: MPFuture):
    key_manager = self._matchmaking.group_key_manager
    peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
    peer_priority = {peer: float(info.value) for peer, info in peer_priority.items()
                     if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))}

    if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
        logger.info(f"Averager could not load state from peers: peer dict is absent or corrupted {peer_priority}.")
        future.set_result(None)
        return

    metadata = None
    for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
        if peer != self.endpoint:
            logger.info(f"Downloading parameters from peer {peer}")
            stream = None
            try:
                leader_stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                stream = leader_stub.rpc_download_state(averaging_pb2.DownloadRequest())
                current_tensor_parts, tensors = [], []
                async for message in stream:
                    if message.metadata:
                        metadata = self.serializer.loads(message.metadata)
                    if message.tensor_part.dtype and current_tensor_parts:
                        # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                        tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
                        current_tensor_parts = []
                    current_tensor_parts.append(message.tensor_part)
                if current_tensor_parts:
                    tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))

                future.set_result((metadata, tensors))
                self.last_updated = get_dht_time()
                return
            except grpc.aio.AioRpcError as e:
                logger.info(f"Failed to download state from {peer} - {e}")
            finally:
                if stream is not None:
                    await stream.code()  # wait for the RPC to finish before moving on to the next peer
    else:
        # for/else: this branch runs only if the loop above finished without returning,
        # i.e. there were no other peers or every download attempt failed
        logger.warning("Averager could not load state from peers: found no active peers.")
        future.set_result(None)
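# Illustrative sketch, not from the original source: the download loop above groups streamed
# chunks into tensors by treating a non-empty `dtype` field as a tensor boundary. The plain-Python
# model below mirrors that grouping; `Part` is a hypothetical stand-in for the streamed
# tensor_part messages, and each returned group is what combine_from_streaming(...) would consume.
from dataclasses import dataclass
from typing import List

@dataclass
class Part:
    dtype: str    # non-empty only on the first chunk of each tensor
    buffer: bytes

def regroup(parts: List[Part]) -> List[List[Part]]:
    tensors, current = [], []
    for part in parts:
        if part.dtype and current:  # a new header arrived: wrap up the previous tensor
            tensors.append(current)
            current = []
        current.append(part)
    if current:                     # flush the trailing tensor after the stream ends
        tensors.append(current)
    return tensors

stream = [Part("float32", b"\x00"), Part("", b"\x01"), Part("float32", b"\x02")]
assert [len(group) for group in regroup(stream)] == [2, 1]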