예제 #1
0
    async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                             ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
        """
        Accept or reject a join request from another averager; if accepted, keep the stream open until the
        group is either assembled (the follower is told to begin allreduce) or disbanded.

        :param request: join request sent by the prospective follower; request.endpoint is used as its key
          in self.current_followers
        :param context: gRPC servicer context (not used by this method)
        :returns: an async stream of MessageFromLeader. The first message is either a rejection (the stream
          ends there) or ACCEPTED. After ACCEPTED, exactly one more message follows: BEGIN_ALLREDUCE with the
          group parameters, GROUP_DISBANDED (optionally with a suggested_leader), or INTERNAL_ERROR if an
          unexpected exception was raised.
        """
        try:
            async with self.lock_request_join_group:
                reason_to_reject = self._check_reasons_to_reject(request)
                if reason_to_reject is not None:
                    yield reason_to_reject
                    return

                self.current_followers[request.endpoint] = request
                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)

                # the "+ 1" accounts for the leader itself, which is not stored in current_followers
                if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
                    # outcome 1: we have assembled a full group and are ready for allreduce
                    await self.leader_assemble_group()

            # wait for the group to be assembled or disbanded
            timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
            # asyncio.wait no longer accepts bare coroutines (deprecated in python3.8, an error since 3.11),
            # so the event wait is wrapped into a task explicitly; ensure_future keeps python3.7 compatibility
            accepted_elsewhere = asyncio.ensure_future(self.was_accepted_to_group.wait())
            try:
                await asyncio.wait({self.assembled_group, accepted_elsewhere},
                                   return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
            finally:
                # do not leave a pending waiter behind if we timed out or were cancelled; no-op if already done
                accepted_elsewhere.cancel()

            if not self.assembled_group.done() and not self.was_accepted_to_group.is_set():
                async with self.lock_request_join_group:
                    if self.assembled_group.done():
                        pass  # this covers a rare case when the group is assembled while the event loop was busy.
                    elif len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
                        # outcome 2: the time is up, run allreduce with what we have or disband
                        await self.leader_assemble_group()
                    else:
                        await self.leader_disband_group()

            # the order of these checks matters: result() is only called on a future that is done and not cancelled
            if self.was_accepted_to_group.is_set() or not self.assembled_group.done() \
                    or self.assembled_group.cancelled() or request.endpoint not in self.assembled_group.result():
                if self.current_leader is not None:
                    # outcome 3: found by a leader with higher priority, send our followers to him
                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
                                                          suggested_leader=self.current_leader)
                    return
                else:
                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
                    return

            allreduce_group = self.assembled_group.result()
            yield averaging_pb2.MessageFromLeader(
                code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
                ordered_group_endpoints=allreduce_group.ordered_group_endpoints, part_sizes=allreduce_group.part_sizes,
                gathered=allreduce_group.gathered, group_key_seed=allreduce_group.group_key_seed)
        except (concurrent.futures.CancelledError, asyncio.CancelledError):
            return  # note: this is a compatibility layer for python3.7
        except Exception as e:
            logger.exception(e)
            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)

        finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
            # forget this follower (if it was ever registered) and signal that the follower set has changed
            self.current_followers.pop(request.endpoint, None)
            self.follower_was_discarded.set()
예제 #2
0
    def _check_reasons_to_reject(self, request: averaging_pb2.JoinRequest) -> Optional[averaging_pb2.MessageFromLeader]:
        """
        Validate a join request against the current state of this averager.

        Checks are applied in order and the first one that fails determines the rejection code.

        :param request: join request received from another averager
        :returns: if accepted, return None, otherwise return a reason for rejection
        """
        if not self.is_looking_for_group or self.assembled_group.done():
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_LOOKING_FOR_GROUP)

        # note: this condition used to start with `request.ListFields() == 3 and ...`, which compares a list
        # to an int and is therefore always False; it silently disabled the schema_hash type check below.
        # The guard is removed so that schema_hash is validated the same way as expiration and endpoint.
        if not isinstance(request.schema_hash, bytes) or len(request.schema_hash) == 0 \
                or not isinstance(request.expiration, DHTExpiration) or not isfinite(request.expiration) \
                or not isinstance(request.endpoint, Endpoint) or len(request.endpoint) == 0 or self.client_mode:
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.PROTOCOL_VIOLATION)

        elif request.schema_hash != self.schema_hash:
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_SCHEMA_HASH)
        elif self.potential_leaders.declared_group_key is None:
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_DECLARED)
        elif self.potential_leaders.declared_expiration_time > (request.expiration or float('inf')):
            # a falsy (zero) request.expiration is treated as +inf here, i.e. it never fails this check
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.BAD_EXPIRATION_TIME)
        elif self.current_leader is not None:
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.NOT_A_LEADER, suggested_leader=self.current_leader
                                                   )  # note: this suggested leader is currently ignored
        elif request.endpoint == self.endpoint or request.endpoint in self.current_followers:
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_ENDPOINT)
        elif len(self.current_followers) + 1 >= self.target_group_size:
            # the "+ 1" accounts for this averager itself, which is not stored in current_followers
            return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
        else:
            return None