def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
    torch.manual_seed(0)
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor

    X = torch.randn(*size)
    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)

    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
    assert error.square().mean() < alpha

    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha
async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
    """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
    serialized_tensor_part = serialize_torch_tensor(local_part, self.compression_type, allow_inplace=False)
    chunks = split_for_streaming(serialized_tensor_part, self.chunk_size_bytes)

    stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
    await stream.write(averaging_pb2.AveragingData(code=averaging_pb2.PART_FOR_AVERAGING,
                                                   group_id=self.group_id, endpoint=self.endpoint,
                                                   tensor_part=next(chunks)))
    for chunk in chunks:
        await stream.write(averaging_pb2.AveragingData(tensor_part=chunk))
    await stream.done_writing()

    outputs: Sequence[averaging_pb2.AveragingData] = [message async for message in stream]
    code = outputs[0].code if outputs else averaging_pb2.INTERNAL_ERROR
    if code != averaging_pb2.AVERAGED_PART:
        raise AllreduceException(f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)}"
                                 f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                 f" allreduce failed")

    averaged_part = deserialize_torch_tensor(combine_from_streaming([message.tensor_part for message in outputs]))
    self.register_averaged_part(peer_endpoint, averaged_part)
    return averaged_part
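# Standalone sketch of the serialize -> split -> combine -> deserialize round trip that
# _communicate_with_peer above relies on. It assumes split_for_streaming and
# combine_from_streaming are importable from hivemind.utils alongside the (de)serialization
# helpers; the tensor size and chunk size are arbitrary illustrative values, not hivemind defaults.
def _streaming_roundtrip_sketch():
    import torch
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils import (serialize_torch_tensor, deserialize_torch_tensor,
                                split_for_streaming, combine_from_streaming)

    tensor = torch.randn(1024, 1024)
    serialized = serialize_torch_tensor(tensor, CompressionType.FLOAT16, allow_inplace=False)

    # split one large protobuf into chunks small enough to stream over gRPC
    chunks = list(split_for_streaming(serialized, 2 ** 16))

    # the receiving side reassembles the chunks and deserializes the result
    restored = deserialize_torch_tensor(combine_from_streaming(chunks))
    assert restored.shape == tensor.shape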
def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
    torch.manual_seed(0)
    X = torch.randn(*size)
    assert torch.allclose(deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.NONE)), X)

    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.MEANSTD_LAST_AXIS_FLOAT16)) - X
    assert error.square().mean() < alpha

    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.FLOAT16)) - X
    assert error.square().mean() < alpha

    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.QUANTILE_8BIT)) - X
    assert error.square().mean() < beta
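# Informal companion to the tests above: print the round-trip MSE for each compression mode
# rather than asserting thresholds. It uses only names that appear in the tests; the tensor
# size is an arbitrary illustrative choice.
def compare_compression_error():
    import torch
    from hivemind.proto.runtime_pb2 import CompressionType
    from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor

    X = torch.randn(128, 128, 64)
    for compression in (CompressionType.NONE, CompressionType.FLOAT16,
                        CompressionType.MEANSTD_LAST_AXIS_FLOAT16, CompressionType.QUANTILE_8BIT):
        error = deserialize_torch_tensor(serialize_torch_tensor(X, compression)) - X
        print(CompressionType.Name(compression), error.square().mean().item())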
async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
    inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
    future = self.experts[request.uid].forward_pool.submit_task(*inputs)
    serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
    return runtime_pb2.ExpertResponse(tensors=serialized_response)
def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
    assert not torch.is_grad_enabled()
    (info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample,
     detect_anomalies) = ctx._saved_non_tensors
    alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors

    dummy_grad_mask, *flat_grad_outputs = raw_grads

    flat_grad_outputs_cpu = []
    for tensor in flat_grad_outputs:
        if detect_anomalies and not tensor.isfinite().all():
            raise ValueError("One of gradients has nan/inf values")
        flat_grad_outputs_cpu.append(tensor.cpu())

    num_samples, max_experts = dummy_grad_mask.shape

    inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
    grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu))
    backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))

    # dispatch tasks to all remote experts, collect responses
    pending_tasks = {}
    for i, j, inputs_ij, grad_outputs_ij in zip(alive_ii.cpu().numpy(), alive_jj.cpu().numpy(),
                                                inputs_per_expert, grad_outputs_per_expert):
        expert = expert_per_sample[i.item()][j.item()]
        stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
        inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
        tensors_serialized = [serialize_torch_tensor(tensor, proto.compression)
                              for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)]
        new_task = stub.backward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=tensors_serialized))
        pending_tasks[new_task] = (i, j)

    backward_survivor_indices, survivor_grad_inputs = cls._collect_responses(
        pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies)
    if len(backward_survivor_indices) == 0:
        raise TimeoutError("Backward pass: no alive experts responded within timeout.")

    # assemble responses
    backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))
    survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
    # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]

    grad_inputs = []
    for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
        grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
            (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
            device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
        grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked

        # sum gradients from each expert
        grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))

    return (DUMMY, None, None, None, None, None, None, None, None, *grad_inputs)
async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                    ) -> Iterable[runtime_pb2.Tensor]:
    """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
    tensor_part: torch.Tensor = deserialize_torch_tensor(combine_from_streaming(stream_messages))
    averaged_part = await self.accumulate_part(source, tensor_part)
    if not self.averaged_part_stream.done():
        serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
        self.averaged_part_stream.set_result(stream_chunks)
        return stream_chunks
    else:
        return self.averaged_part_stream.result()
def forward(cls, ctx, dummy, experts_per_sample: List[List[RemoteExpert]], k_min: int, backward_k_min: int,
            timeout_after_k_min: float, forward_timeout: Optional[float], backward_timeout: Optional[float],
            detect_anomalies: bool, info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
    assert not torch.is_grad_enabled()
    num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))

    flat_inputs_cpu = []
    for tensor in flat_inputs:
        if detect_anomalies and not tensor.isfinite().all():
            raise ValueError("One of inputs has nan/inf values")
        flat_inputs_cpu.append(tensor.cpu())

    flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
    assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

    # dispatch tasks to all remote experts, collect responses
    pending_tasks: Dict[grpc.Future, Tuple[int, int]] = {}
    for i in range(num_samples):
        for j, expert in enumerate(experts_per_sample[i]):
            input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
                flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
            stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
            new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
            pending_tasks[new_task] = (i, j)

    alive_grid_indices, alive_flat_outputs = cls._collect_responses(
        pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies)
    if len(alive_grid_indices) == 0:
        raise TimeoutError("Forward pass: no alive experts responded within timeout.")

    # assemble responses
    alive_ii, alive_jj = map(torch.as_tensor, zip(*alive_grid_indices))
    mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
    mask[alive_ii, alive_jj] = True

    alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
    # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

    outputs = []
    for response_stacked in alive_flat_outputs_stacked:
        output = torch.zeros(
            [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
            dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
        output[alive_ii, alive_jj] = response_stacked
        outputs.append(output.to(flat_inputs[0].device))

    # save individual outputs for backward pass
    ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
    ctx._saved_non_tensors = (info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample,
                              detect_anomalies)
    return (mask,) + tuple(outputs)
async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
    inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
    future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
    serialized_response = [
        serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
        for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))]
    return runtime_pb2.ExpertResponse(tensors=serialized_response)
async def _forward_one_expert(grid_indices: Tuple[int, ...], expert: RemoteExpert, inputs: Tuple[torch.Tensor]):
    stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint, aio=True)
    try:
        outputs = await stub.forward(runtime_pb2.ExpertRequest(
            uid=expert.uid, tensors=[serialize_torch_tensor(tensor) for tensor in inputs]))
        return grid_indices, tuple(deserialize_torch_tensor(tensor) for tensor in outputs.tensors)
    except grpc.experimental.aio.AioRpcError as error:
        logger.warning(f"RemoteExpert {expert} failed forward: {error.code()} (inputs: {inputs})")
async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                    ) -> Iterable[runtime_pb2.Tensor]:
    """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
    try:
        tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
    except RuntimeError as e:
        raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")

    averaged_part = await self.accumulate_part(source, tensor_part, weight=self.peer_weights[source])
    if not self.averaged_part_stream.done():
        serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
        stream_chunks = tuple(split_for_streaming(serialized_tensor, self.chunk_size_bytes))
        self.averaged_part_stream.set_result(stream_chunks)
        return stream_chunks
    else:
        return self.averaged_part_stream.result()
def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
    t = time.time()
    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
    return time.time() - t
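# Hedged usage sketch for benchmark_compression above: time one serialize + deserialize round
# trip per compression mode on a single random tensor. The tensor size and the single-run timing
# (no warmup, no averaging over repeats) are illustrative simplifications, not benchmark settings
# prescribed by hivemind.
def run_compression_benchmarks():
    import torch
    from hivemind.proto.runtime_pb2 import CompressionType

    tensor = torch.randn(1000, 1000)
    for compression in (CompressionType.NONE, CompressionType.FLOAT16,
                        CompressionType.MEANSTD_LAST_AXIS_FLOAT16, CompressionType.QUANTILE_8BIT):
        elapsed = benchmark_compression(tensor, compression)
        print(f"{CompressionType.Name(compression)}: {elapsed:.4f}s")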