async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                         ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
    """
    Accept or reject a join request from another averager; if accepted, run him through allreduce steps.

    Messages yielded to the follower:
      * a rejection message from ``_check_reasons_to_reject``, after which the stream ends; or
      * ``ACCEPTED``, followed by exactly one of:
          - ``BEGIN_ALLREDUCE`` with the assembled group's id, ordered endpoints, part sizes,
            gathered data and key seed, if the follower is part of the assembled group;
          - ``GROUP_DISBANDED`` (with ``suggested_leader`` if this averager itself found a leader)
            otherwise;
      * ``INTERNAL_ERROR`` if an unexpected exception occurs (the exception is logged).

    Cancellation ends the stream silently. In all cases the follower is removed from
    ``self.current_followers`` and ``self.follower_was_discarded`` is set on exit.
    """
    try:
        async with self.lock_request_join_group:
            reason_to_reject = self._check_reasons_to_reject(request)
            if reason_to_reject is not None:
                yield reason_to_reject
                return

            self.current_followers[request.endpoint] = request
            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)

            if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
                # outcome 1: we have assembled a full group and are ready for allreduce
                await self.leader_assemble_group()

        # wait (without holding the lock, so that other followers can join) until the group is assembled
        # or disbanded, until this averager is accepted by another leader, or until our declaration expires
        timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
        # asyncio.wait() must be given tasks/futures, not bare coroutines (TypeError on python3.11+);
        # the waiter is cancelled afterwards so it does not stay pending if the event is never set
        accepted_waiter = asyncio.create_task(self.was_accepted_to_group.wait())
        try:
            await asyncio.wait({self.assembled_group, accepted_waiter},
                               return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
        finally:
            accepted_waiter.cancel()

        if not self.assembled_group.done() and not self.was_accepted_to_group.is_set():
            async with self.lock_request_join_group:
                if self.assembled_group.done():
                    pass  # this covers a rare case when the group is assembled while the event loop was busy.
                elif len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
                    # outcome 2: the time is up, run allreduce with what we have or disband
                    await self.leader_assemble_group()
                else:
                    await self.leader_disband_group()

        if self.was_accepted_to_group.is_set() or not self.assembled_group.done() \
                or self.assembled_group.cancelled() or request.endpoint not in self.assembled_group.result():
            if self.current_leader is not None:
                # outcome 3: found by a leader with higher priority, send our followers to him
                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
                                                      suggested_leader=self.current_leader)
                return
            else:
                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
                return

        allreduce_group = self.assembled_group.result()
        yield averaging_pb2.MessageFromLeader(
            code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
            ordered_group_endpoints=allreduce_group.ordered_group_endpoints,
            part_sizes=allreduce_group.part_sizes,
            gathered=allreduce_group.gathered, group_key_seed=allreduce_group.group_key_seed)

    except (concurrent.futures.CancelledError, asyncio.CancelledError):
        return  # note: this is a compatibility layer for python3.7
    except Exception as e:
        logger.exception(e)
        yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)

    finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
        self.current_followers.pop(request.endpoint, None)
        self.follower_was_discarded.set()
def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
    """
    Decide whether this averager (acting as a leader) can accept an incoming join request.

    Checks are applied in a fixed order and the first failing one determines the rejection code:
    NOT_LOOKING_FOR_GROUP, PROTOCOL_VIOLATION (malformed request, or this averager is in client mode),
    BAD_SCHEMA_HASH, NOT_DECLARED, BAD_EXPIRATION_TIME, NOT_A_LEADER, DUPLICATE_ENDPOINT, GROUP_IS_FULL.

    :param request: the join request sent by a prospective follower
    :returns: if accepted, return None, otherwise return a reason for rejection
    """
    if not self.is_looking_for_group or self.assembled_group.done():
        return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)

    # every field the leader relies on must be present and well-typed; each check is evaluated directly
    # (a `request.ListFields() == 3` guard compares a list to an int and would silently disable checks)
    request_is_malformed = (
        not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0
        or not isinstance(request.expiration, DHTExpiration) or not isfinite(request.expiration)
        or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0
    )
    if request_is_malformed or self.client_mode:
        return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)
    elif request.schema_hash != self.schema_hash:
        return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
    elif self.potential_leaders.declared_group_key is None:
        return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
    elif self.potential_leaders.declared_expiration_time > (request.expiration or float('inf')):
        # a zero expiration in the request is treated as "no deadline"
        return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
    elif self.current_leader is not None:
        return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
                                               )  # note: this suggested leader is currently ignored
    elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
        return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
    elif len(self.current_followers) + 1 >= self.target_group_size:
        # +1 accounts for the leader itself
        return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
    else:
        return None