def test_serialize_tensor():
    tensor = torch.randn(512, 12288)

    # lossless serialization: split into chunks of several sizes, then recombine
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10 ** 9]:
        chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
        restored = hivemind.combine_from_streaming(chunks)
        assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    # lossy float16 compression: the round-trip must stay within atol=1e-2
    chunk_size = 30 * 1024
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.FLOAT16)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=0, atol=1e-2)

    # integer tensors must round-trip exactly
    tensor = torch.randint(0, 100, (512, 1, 1))
    serialized_tensor = serialize_torch_tensor(tensor, CompressionType.NONE)
    chunks = list(hivemind.split_for_streaming(serialized_tensor, chunk_size))
    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
    restored = hivemind.combine_from_streaming(chunks)
    assert torch.allclose(deserialize_torch_tensor(restored), tensor)

    # zero-dimensional (scalar) tensors must survive both code paths
    scalar = torch.tensor(1.0)
    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.NONE)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)

    serialized_scalar = serialize_torch_tensor(scalar, CompressionType.FLOAT16)
    assert torch.allclose(deserialize_torch_tensor(serialized_scalar), scalar)
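# Hedged note on the chunk-count assertions above and below: for a buffer of b bytes
# and a chunk size of c, the expression (b - 1) // c + 1 is integer ceiling division,
# i.e. it equals math.ceil(b / c) for any b >= 1. A quick self-check:
import math
assert all((b - 1) // c + 1 == math.ceil(b / c) for b in range(1, 1000) for c in (1, 7, 64, 1024))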
async def rpc_download_state(
    self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
) -> AsyncIterator[averaging_pb2.DownloadData]:
    """
    Get the up-to-date trainer state from a peer.
    The state consists of two parts: (serialized_metadata, tensors)

     - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
     - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
    """
    if not self.allow_state_sharing:
        return  # deny request and direct peer to the next prospective averager
    chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
    metadata, tensors = await self._get_current_state_from_host_process()

    for tensor in tensors:
        for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
            if metadata is not None:
                yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                metadata = None  # attach metadata to the first chunk only
            else:
                yield averaging_pb2.DownloadData(tensor_part=part)
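# Hedged sketch (not part of the original snippet): a client-side counterpart to
# rpc_download_state above. It assumes the conventions visible in this file: metadata
# rides on the first DownloadData message only, and each tensor's first chunk is the
# only one with its `dtype` field set, which marks tensor boundaries in the stream.
# `stub` and `serializer` are assumed arguments for illustration.
async def download_state_sketch(stub, serializer):
    stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
    metadata, tensors, current_tensor_parts = None, [], []
    async for message in stream:
        if message.metadata:
            metadata = serializer.loads(message.metadata)
        if message.tensor_part.dtype and current_tensor_parts:
            # a new tensor has begun: finalize the previous one from its chunks
            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
            current_tensor_parts = []
        current_tensor_parts.append(message.tensor_part)
    if current_tensor_parts:
        tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
    return metadata, tensors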
def test_split_parts():
    tensor = torch.randn(910, 512)
    serialized_tensor_part = serialize_torch_tensor(tensor, allow_inplace=False)
    chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
    assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))

    chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
    assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))

    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10 ** 9))
    assert len(chunks3) == 1

    compressed_tensor_part = serialize_torch_tensor(tensor, CompressionType.FLOAT16, allow_inplace=False)
    chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))  # FLOAT16 stores 2 bytes per value

    combined1 = hivemind.utils.combine_from_streaming(chunks1)
    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
    combined3 = hivemind.utils.combine_from_streaming(chunks3)
    combined4 = hivemind.utils.combine_from_streaming(chunks4)

    for combined in combined1, combined2, combined3:
        assert torch.allclose(tensor, deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)

    assert torch.allclose(tensor, deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)

    # deserializing an incomplete stream must fail loudly rather than return garbage
    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
        with pytest.raises(RuntimeError):
            deserialize_torch_tensor(combined)
def handle_add_torch(args):
    args = MSGPackSerializer.loads(args)
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(args[0])
    result = deserialize_torch_tensor(tensor)

    for i in range(1, len(args)):
        tensor = runtime_pb2.Tensor()
        tensor.ParseFromString(args[i])
        result = result + deserialize_torch_tensor(tensor)

    return serialize_torch_tensor(result).SerializeToString()
async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
    """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
    assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
    if peer_endpoint == self.endpoint:
        return await self.accumulate_part(self.endpoint, local_part, weight=self.peer_weights[self.endpoint])

    serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
    chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)

    stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
    # the first message carries the routing fields (code, group_id, endpoint); later ones carry raw chunks
    await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                                   group_id=self.group_id, endpoint=self.endpoint,
                                                   tensor_part=next(chunks)))
    for chunk in chunks:
        await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
    await stream.done_writing()

    outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
    code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
    if code != averaging_pb2.AVERAGED_PART:
        raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                                 f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                 f" allreduce failed")

    try:
        # the peer returns the delta (averaged - local), so the local part is added back here
        averaged_part = local_part + deserialize_torch_tensor(
            combine_from_streaming([message.tensor_part for message in outputs]))
    except RuntimeError as e:
        raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")

    self.register_averaged_part(peer_endpoint, averaged_part)
    return averaged_part
async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
    inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
    future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
    serialized_response = [
        serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
        for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
    ]
    return runtime_pb2.ExpertResponse(tensors=serialized_response)
async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                    ) -> Iterable[runtime_pb2.Tensor]:
    """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
    try:
        tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
    except RuntimeError as e:
        raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming: {e}")

    averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
    # send back the delta (averaged - original); the sender adds its local part back on receipt
    serialized_tensor = serialize_torch_tensor(averaged_part - tensor_part, self.compression_type,
                                               allow_inplace=False)
    stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
    return stream_chunks
async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
    handle = handle_square_torch
    server = await P2P.create()
    await server.add_stream_handler(handler_name, handle)

    nodes = bootstrap_from([server])
    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
    await client.wait_for_at_least_n_peers(1)

    inp = serialize_torch_tensor(test_input).SerializeToString()
    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
    result = runtime_pb2.Tensor()
    result.ParseFromString(result_pb)
    result = deserialize_torch_tensor(result)
    assert torch.allclose(result, expected)

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()
async def test_call_peer_error(replicate, handler_name="handle"):
    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)
    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)

    nodes = bootstrap_from([server])
    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
    client = await replicate_if_needed(client_primary, replicate)
    await client.wait_for_at_least_n_peers(1)

    inp = [serialize_torch_tensor(i).SerializeToString()
           for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
    inp_msgp = MSGPackSerializer.dumps(inp)
    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
    assert result == b'something went wrong :('

    await server.stop_listening()
    await server_primary.shutdown()
    await client_primary.shutdown()
def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    torch.manual_seed(0)
    X = torch.randn(*size)
    # NONE must be a lossless round-trip
    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)
    # 16-bit schemes must satisfy the tight MSE bound alpha
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_16BIT)) - X
    assert error.square().mean() < alpha
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
    # 8-bit schemes get the looser MSE bound beta
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta
    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
    assert error.square().mean() < beta

    # every scheme must handle an all-zeros tensor without producing nan/inf
    zeros = torch.zeros(5, 5)
    for compression_type in CompressionType.values():
        assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
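# Hedged companion sketch to the error-bound test above: report the serialized payload
# size per compression scheme, since the accuracy/bandwidth trade-off is the point of
# compressing at all. Sizes are read off the protobuf wire format; the helper name is
# our own, not part of the library.
def compression_sizes(tensor: torch.Tensor) -> dict:
    return {
        CompressionType.Name(compression_type):
            len(serialize_torch_tensor(tensor, compression_type).SerializeToString())
        for compression_type in CompressionType.values()
    }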
def handle_square_torch(x):
    tensor = runtime_pb2.Tensor()
    tensor.ParseFromString(x)
    tensor = deserialize_torch_tensor(tensor)
    result = tensor ** 2
    return serialize_torch_tensor(result).SerializeToString()
def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
    # time a full serialize + deserialize round-trip for the given compression scheme
    t = time.time()
    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
    return time.time() - t
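# Hedged usage sketch for benchmark_compression above: time one round-trip per
# compression scheme on a mid-sized tensor. Timings vary with hardware; this only
# illustrates the calling convention, and the tensor shape is an arbitrary choice.
def run_compression_benchmarks():
    sample = torch.randn(1024, 1024)
    for compression_type in CompressionType.values():
        elapsed = benchmark_compression(sample, compression_type)
        print(f"{CompressionType.Name(compression_type):>16}: {elapsed:.4f} s")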
def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
    assert not torch.is_grad_enabled()
    (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
     detect_anomalies) = ctx._saved_non_tensors
    alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors

    dummy_grad_mask, *flat_grad_outputs = raw_grads

    flat_grad_outputs_cpu = []
    for tensor in flat_grad_outputs:
        if detect_anomalies and not tensor.isfinite().all():
            raise ValueError("One of the gradients has nan/inf values")
        flat_grad_outputs_cpu.append(tensor.cpu())

    num_samples, max_experts = dummy_grad_mask.shape

    inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
    grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0)
                                    for tensor in flat_grad_outputs_cpu))
    backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))

    # dispatch tasks to all remote experts, collect responses
    pending_tasks = {}
    for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                inputs_per_expert, grad_outputs_per_expert):
        expert = expert_per_sample[i.item()][j.item()]
        stub = _get_expert_stub(expert.endpoint)
        inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
        tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
                              for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
        new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
        pending_tasks[new_task] = (i, j)

    survivor_inds, survivor_grad_inputs = cls._collect_responses(
        pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
    if len(survivor_inds) < backward_k_min:
        raise TimeoutError(f"Backward pass: fewer than {backward_k_min} experts responded within timeout.")

    # assemble responses
    batch_inds, expert_inds = map(lambda x: torch.as_tensor(x, dtype=torch.long),
                                  list(zip(*survivor_inds)) or ([], []))

    # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
    survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))

    grad_inputs = nested_map(
        lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
        list(nested_flatten(info['forward_schema'])))

    for grad_input, survivor_grad_stacked in zip(grad_inputs, survivor_grad_inputs_stacked):
        grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
            (num_samples, max_experts, *grad_input.shape[1:]),
            device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
        grad_input_per_expert[batch_inds, expert_inds] = survivor_grad_stacked
        grad_input.copy_(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))

    return (DUMMY, None, None, None, None, None, None, None, None, None, *grad_inputs)
def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int,
            backward_k_min: int, timeout_after_k_min: float, forward_timeout: Optional[float],
            backward_timeout: Optional[float], detect_anomalies: bool, allow_zero_outputs: bool,
            info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
    assert not torch.is_grad_enabled()
    num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

    flat_inputs_cpu = []
    for tensor in flat_inputs:
        if detect_anomalies and not tensor.isfinite().all():
            raise ValueError("One of the inputs has nan/inf values")
        flat_inputs_cpu.append(tensor.cpu())

    flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
    assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

    # dispatch tasks to all remote experts, collect responses
    pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
    for i in range(num_samples):
        for j, expert in enumerate(experts_per_sample[i]):
            input_tensors = [serialize_torch_tensor(tensor, proto.compression)
                             for tensor, proto in zip(flat_inputs_per_sample[i],
                                                      nested_flatten(info['forward_schema']))]
            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
            new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
            pending_tasks[new_task] = (i, j)

    responded_inds, alive_flat_outputs = cls._collect_responses(
        pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
    if len(responded_inds) < k_min:
        raise TimeoutError(f"Forward pass: fewer than {k_min} experts responded within timeout.")

    if not isinstance(info['outputs_schema'], tuple):
        outputs_schema = (info['outputs_schema'],)
    else:
        outputs_schema = info['outputs_schema']
    outputs = nested_map(
        lambda descriptor: descriptor.make_empty(num_samples, max_experts,
                                                 device=flat_inputs[0].device).zero_(),
        outputs_schema)

    # assemble responses
    if len(responded_inds) > 0 or allow_zero_outputs:
        batch_inds, expert_inds = map(
            lambda x: torch.as_tensor(x, device=flat_inputs[0].device, dtype=torch.long),
            list(zip(*responded_inds)) or ([], []))

        # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
        alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))

        for output, response_stacked in zip(outputs, alive_flat_outputs_stacked):
            output[batch_inds, expert_inds] = response_stacked.to(output.device)
    else:
        raise RuntimeError('Forward pass: 0 experts responded within timeout and allow_zero_outputs is False')

    mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
    mask[batch_inds, expert_inds] = True

    # save individual outputs for backward pass
    ctx.save_for_backward(batch_inds, expert_inds, *flat_inputs_cpu)
    ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min,
                              experts_per_sample, detect_anomalies)
    return (mask,) + outputs
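# Hedged sketch (an illustration, not library code): one way a caller could reduce the
# per-expert outputs returned by forward above. `mask` is [num_samples, max_experts] and
# each output is [num_samples, max_experts, ...]; averaging over the experts that
# actually responded is a typical reduction, but the uniform weighting here is our
# assumption, not something the snippet above prescribes.
def average_responded_experts(mask: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    weights = mask.to(output.dtype)  # 1.0 for experts that responded, 0.0 otherwise
    while weights.dim() < output.dim():
        weights = weights.unsqueeze(-1)  # broadcast over trailing output dimensions
    num_responded = mask.sum(dim=1).clamp_min(1)  # avoid division by zero
    summed = (output * weights).sum(dim=1)
    return summed / num_responded.view(-1, *([1] * (summed.dim() - 1)))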