async def follower_assemble_group(self, leader: Endpoint, msg: averaging_pb2.MessageFromLeader) -> AllReduceRunner:
    """
    Build the AllReduceRunner for a group whose composition was decided by our leader.

    :param leader: the endpoint we are currently following; must match self.current_leader
    :param msg: leader's message carrying group id, peer ordering, and part sizes
    :returns: the runner, which is also published via self.assembled_group
    """
    # sanity checks: both matchmaking locks must be held and the group future unresolved
    assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
    assert not self.assembled_group.done()
    assert self.current_leader == leader, f"averager does not follow {leader} (actual: {self.current_leader})"

    group_id = msg.group_id
    ordered_group_endpoints = msg.ordered_group_endpoints
    part_sizes = msg.part_sizes

    # the leader's roster must include us and be internally consistent
    assert self.endpoint in ordered_group_endpoints, "Leader sent us group_endpoints that does not contain us!"
    assert len(ordered_group_endpoints) == len(part_sizes) == len(msg.gathered)

    logger.debug(f"{self.endpoint} - follower started allreduce after being prompted by leader {leader}.")

    runner = AllReduceRunner(
        group_id=group_id,
        tensors=self.averaged_tensors,
        endpoint=self.endpoint,
        ordered_group_endpoints=tuple(ordered_group_endpoints),
        part_sizes=tuple(part_sizes),
        gathered=msg.gathered,
        **self.allreduce_kwargs,
    )
    # hand the assembled group to whoever awaits self.assembled_group
    self.assembled_group.set_result(runner)
    return runner
async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
    """
    Use a group description found by Matchmaking to form an AllReduceRunner.

    :param group_info: group id, ordered endpoints, and serialized per-peer metadata from matchmaking
    :param min_vector_size: minimum chunk size passed to load_balance_peers
    :param kwargs: extra keyword arguments forwarded to AllReduceRunner
    :returns: a runner configured for this group (return_deltas=True)
    :raises MatchmakingException: if metadata deserialization or runner construction fails;
        the original exception is chained as __cause__
    """
    try:
        # each gathered entry deserializes to (weight, throughput, mode, user_payload)
        weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
        user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))

        # compute optimal part sizes from peer throughputs; non-listening peers get zero share
        incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
        # load balancing may be CPU-heavy, so run it off the event loop;
        # get_running_loop() is the modern form of get_event_loop() inside a coroutine
        part_sizes = await asyncio.get_running_loop().run_in_executor(
            None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)

        async with self.get_tensors_async() as averaged_tensors:
            return AllReduceRunner(
                group_id=group_info.group_id,
                tensors=averaged_tensors,
                endpoint=self.endpoint,
                ordered_group_endpoints=group_info.endpoints,
                part_sizes=part_sizes,
                weights=weights,
                gathered=user_gathered,
                return_deltas=True,
                **kwargs,
            )
    except Exception as e:
        # chain the cause so the underlying traceback is not lost when re-raising
        raise MatchmakingException(
            f"Unable to create allreduce runner ({e}), group_info: {group_info}") from e
async def leader_assemble_group(self) -> AllReduceRunner:
    """
    As group leader, freeze the current set of followers into a group and
    prepare the AllReduceRunner that will execute the averaging round.

    :returns: the runner, which is also published via self.assembled_group
    """
    # both matchmaking locks must be held and the group future unresolved
    assert self.lock_looking_for_group.locked() and self.lock_request_join_group.locked()
    assert not self.assembled_group.done()

    group_id = DHTID.generate().to_bytes()
    # roster = followers + ourselves, in randomized order
    ordered_group_endpoints = list(self.current_followers) + [self.endpoint]
    random.shuffle(ordered_group_endpoints)

    throughputs, gathered = [], []
    for peer in ordered_group_endpoints:
        if peer == self.endpoint:
            peer_throughput, peer_gather = self.throughput, self.data_for_gather
        else:
            info = self.current_followers[peer]
            # negative throughput is a sentinel for "unknown"; normalize to None
            peer_throughput = info.throughput if info.throughput >= 0 else None
            peer_gather = info.gather if info.gather else None
        throughputs.append(peer_throughput)
        gathered.append(peer_gather)

    part_sizes = load_balance_peers(self.total_size, throughputs, self.min_vector_size)
    group_key_seed = random.randint(-(2 ** 31), 2 ** 31 - 1)

    logger.debug(f"{self.endpoint} - leader started allreduce for {len(ordered_group_endpoints)} peers.")

    runner = AllReduceRunner(
        group_id=group_id,
        tensors=self.averaged_tensors,
        endpoint=self.endpoint,
        ordered_group_endpoints=ordered_group_endpoints,
        part_sizes=part_sizes,
        gathered=gathered,
        group_key_seed=group_key_seed,
        **self.allreduce_kwargs,
    )
    # rotate the group key before announcing the result to waiters
    await self.group_key_manager.update_key_on_group_assembled(runner, is_leader=True)
    self.assembled_group.set_result(runner)
    return runner