Example #1
    async def _declare_averager_periodically(self,
                                             key_manager: GroupKeyManager):
        async with self.lock_declare:
            try:
                while True:
                    await self.running.wait()

                    new_expiration_time = min(
                        get_dht_time() + self.averaging_expiration,
                        self.search_end_time)
                    self.declared_group_key = group_key = key_manager.current_key
                    self.declared_expiration_time = new_expiration_time
                    self.declared_expiration.set()
                    await key_manager.declare_averager(
                        group_key,
                        self.endpoint,
                        expiration_time=new_expiration_time)
                    await asyncio.sleep(self.declared_expiration_time -
                                        get_dht_time())
                    if self.running.is_set() and len(self.leader_queue) == 0:
                        await key_manager.update_key_on_not_enough_peers()

            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
            finally:
                if self.declared_group_key is not None:
                    prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                    self.declared_group_key, self.declared_expiration_time = None, float('inf')
                    self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
                    self.max_assured_time = float('-inf')
                    await key_manager.declare_averager(prev_declared_key,
                                                       self.endpoint,
                                                       prev_expiration_time,
                                                       looking_for_group=False)
Example #2
    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
        try:
            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
            while get_dht_time() < self.search_end_time:
                new_peers = await key_manager.get_averagers(
                    key_manager.current_key, only_active=True)
                self.max_assured_time = max(
                    self.max_assured_time,
                    get_dht_time() + self.averaging_expiration - DISCREPANCY)

                self.leader_queue.clear()
                for peer, peer_expiration_time in new_peers:
                    if peer == self.endpoint or (
                            peer, peer_expiration_time) in self.past_attempts:
                        continue
                    self.leader_queue.store(peer, peer_expiration_time,
                                            peer_expiration_time)
                    self.max_assured_time = max(
                        self.max_assured_time,
                        peer_expiration_time - DISCREPANCY)

                self.update_finished.set()

                await asyncio.wait(
                    {self.running.wait(), self.update_triggered.wait()},
                    return_when=asyncio.ALL_COMPLETED,
                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
                self.update_triggered.clear()
        except (concurrent.futures.CancelledError, asyncio.CancelledError):
            return  # note: this is a compatibility layer for python3.7
        except Exception as e:
            logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
            raise
Example #3
    async def _update_queue_periodically(self, group_key: GroupKey):
        DISCREPANCY = hivemind.utils.timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
        while get_dht_time() < self.search_end_time:
            new_peers = await self.dht.get_averagers(group_key,
                                                     only_active=True,
                                                     return_future=True)
            self.max_assured_time = max(
                self.max_assured_time,
                get_dht_time() + self.averaging_expiration - DISCREPANCY)

            self.leader_queue.clear()
            for peer, peer_expiration_time in new_peers:
                if peer == self.endpoint or (
                        peer, peer_expiration_time) in self.past_attempts:
                    continue
                self.leader_queue.store(peer, peer_expiration_time,
                                        peer_expiration_time)
                self.max_assured_time = max(self.max_assured_time,
                                            peer_expiration_time - DISCREPANCY)

            self.update_finished.set()

            await asyncio.wait(
                {self.running.wait(), self.update_triggered.wait()},
                return_when=asyncio.ALL_COMPLETED,
                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None)
            self.update_triggered.clear()
Example #4
    async def _declare_averager_periodically(self, group_key: GroupKey):
        async with self.lock_declare:
            try:
                while True:
                    await self.running.wait()

                    new_expiration_time = min(
                        get_dht_time() + self.averaging_expiration,
                        self.search_end_time)
                    self.declared_group_key, self.declared_expiration_time = group_key, new_expiration_time
                    self.declared_expiration.set()
                    await self.dht.declare_averager(group_key,
                                                    self.endpoint,
                                                    new_expiration_time,
                                                    looking_for_group=True,
                                                    return_future=True)
                    await asyncio.sleep(self.declared_expiration_time -
                                        get_dht_time())
            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
                logger.error(f"{self.endpoint} - caught {type(e)}: {e}")
            finally:
                if self.declared_group_key is not None:
                    prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time
                    self.declared_group_key, self.declared_expiration_time = None, float('inf')
                    self.leader_queue = TimedStorage[Endpoint, DHTExpiration]()
                    self.max_assured_time = float('-inf')
                    await self.dht.declare_averager(prev_declared_key,
                                                    self.endpoint,
                                                    prev_expiration_time,
                                                    looking_for_group=False,
                                                    return_future=True)
Example #5
def test_sending_validator_instance_between_processes():
    alice = hivemind.DHT(start=True)
    bob = hivemind.DHT(start=True, initial_peers=[f"{LOCALHOST}:{alice.port}"])

    alice.add_validators([SchemaValidator(SampleSchema)])
    bob.add_validators([SchemaValidator(SampleSchema)])

    assert bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
    assert not bob.store('experiment_name', 777, get_dht_time() + 10)
    assert alice.get('experiment_name', latest=True).value == b'foo_bar'
Example #6
async def test_expecting_public_keys(dht_nodes_with_schema):
    alice, bob = dht_nodes_with_schema

    # Subkeys expected to contain a public key
    # (so hivemind.dht.crypto.RSASignatureValidator would require a signature)
    assert await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
                           subkey=b'uid[owner:public-key]')
    assert not await bob.store('signed_data', b'foo_bar', get_dht_time() + 10,
                               subkey=b'uid-without-public-key')

    for peer in [alice, bob]:
        dictionary = (await peer.get('signed_data', latest=True)).value
        assert (len(dictionary) == 1 and
                dictionary[b'uid[owner:public-key]'].value == b'foo_bar')
Example #7
async def test_expecting_regular_value(dht_nodes_with_schema):
    alice, bob = dht_nodes_with_schema

    # Regular value (bytes) expected
    assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
    assert not await bob.store('experiment_name', 666, get_dht_time() + 10)
    assert not await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10,
                               subkey=b'subkey')

    # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
    assert not await bob.store('experiment_name', [], get_dht_time() + 10)
    assert not await bob.store('experiment_name', [1, 2, 3], get_dht_time() + 10)

    for peer in [alice, bob]:
        assert (await peer.get('experiment_name', latest=True)).value == b'foo_bar'
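The schema tests above (and test_expecting_dictionary further below) rely on a SampleSchema model and a dht_nodes_with_schema fixture that are not shown in this listing. A minimal sketch of what they might look like, assuming pydantic, pytest-asyncio, and hivemind's SchemaValidator; the field types, the BytesWithPublicKey import, and the LOCALHOST value are assumptions inferred from the keys stored in the tests, not code from this listing:

import pytest
from typing import Dict
from pydantic import BaseModel, StrictBytes, conint
from hivemind.dht.node import DHTNode
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator  # BytesWithPublicKey assumed to live here

LOCALHOST = '127.0.0.1'  # assumed value of the constant used throughout these tests


class SampleSchema(BaseModel):
    # field names mirror the keys stored in the tests above
    experiment_name: StrictBytes                               # plain bytes value
    signed_data: Dict[BytesWithPublicKey, StrictBytes]         # subkeys must embed a public key
    n_batches: Dict[StrictBytes, conint(ge=0, strict=True)]    # bytes subkey -> non-negative int


@pytest.fixture
async def dht_nodes_with_schema():
    # hypothetical fixture: two DHTNode peers that share the same schema validator
    validator = SchemaValidator(SampleSchema)
    alice = await DHTNode.create(record_validator=validator)
    bob = await DHTNode.create(record_validator=validator,
                               initial_peers=[f"{LOCALHOST}:{alice.port}"])
    yield alice, bob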
Example #8
    @property
    def request_expiration_time(self) -> float:
        """ this averager's current expiration time - used to send join requests to leaders """
        if isfinite(self.declared_expiration_time):
            return self.declared_expiration_time
        else:
            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
Example #9
    async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
        """
        :param leader: request this peer to be your leader for allreduce
        :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
        :returns: if the leader accepted us and started AllReduce, return the resulting group info. Otherwise, return None
        :note: this function does not guarantee that your group leader is the same as the :leader: parameter
          The originally specified leader can disband the group and redirect us to a different leader
        """
        assert self.is_looking_for_group and self.current_leader is None
        call: Optional[grpc.aio.UnaryStreamCall] = None
        try:
            async with self.lock_request_join_group:
                leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
                    client_mode=self.client_mode, gather=self.data_for_gather))
                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)

                if message.code == averaging_pb2.ACCEPTED:
                    logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
                    self.current_leader = leader
                    self.was_accepted_to_group.set()
                    if len(self.current_followers) > 0:
                        await self.leader_disband_group()

            if message.code != averaging_pb2.ACCEPTED:
                code = averaging_pb2.MessageCode.Name(message.code)
                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
                return None

            async with self.potential_leaders.pause_search():
                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
                message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)

                if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                    async with self.lock_request_join_group:
                        return await self.follower_assemble_group(leader, message)

            if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
                if message.suggested_leader and message.suggested_leader != self.endpoint:
                    logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
                    self.current_leader = None
                    call.cancel()
                    return await self.request_join_group(message.suggested_leader, expiration_time)
                else:
                    logger.debug(f"{self} - leader disbanded group")
                    return None

            logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
            return None
        except asyncio.TimeoutError:
            logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
            if call is not None:
                call.cancel()
            return None
        finally:
            self.was_accepted_to_group.clear()
            self.current_leader = None
            if call is not None:
                await call.code()
Example #10
    @contextlib.asynccontextmanager
    async def begin_search(self,
                           key_manager: GroupKeyManager,
                           timeout: Optional[float],
                           declare: bool = True):
        async with self.lock_search:
            self.running.set()
            self.search_end_time = get_dht_time() + timeout if timeout is not None else float('inf')
            update_queue_task = asyncio.create_task(
                self._update_queue_periodically(key_manager))
            if declare:
                declare_averager_task = asyncio.create_task(
                    self._declare_averager_periodically(key_manager))

            try:
                yield self
            finally:
                if not update_queue_task.done():
                    update_queue_task.cancel()
                if declare and not declare_averager_task.done():
                    declare_averager_task.cancel()

                for field in (self.past_attempts, self.leader_queue,
                              self.running, self.update_finished,
                              self.update_triggered, self.declared_expiration):
                    field.clear()
                self.max_assured_time = float('-inf')
                self.search_end_time = float('inf')
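Because the body above yields self while holding lock_search, begin_search is meant to be used as an async context manager (hence the contextlib.asynccontextmanager decorator assumed above). A hedged usage sketch from a hypothetical caller inside the averager's matchmaking class; pop_next_leader is an assumed helper of the search object and not part of this listing:

        # hypothetical caller: search for a group, then ask the best candidate to lead us
        async with self.potential_leaders.begin_search(key_manager, timeout=30.0) as leaders:
            next_leader = await leaders.pop_next_leader()  # assumed helper coroutine
            # request_expiration_time is the property shown in Example #8 above
            group_info = await self.request_join_group(next_leader, leaders.request_expiration_time)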
Example #11
async def test_keys_outside_schema(dht_nodes_with_schema):
    class Schema(BaseModel):
        some_field: StrictInt

    class MergedSchema(BaseModel):
        another_field: StrictInt

    for allow_extra_keys in [False, True]:
        validator = SchemaValidator(Schema, allow_extra_keys=allow_extra_keys)
        assert validator.merge_with(
            SchemaValidator(MergedSchema, allow_extra_keys=False))

        alice = await DHTNode.create(record_validator=validator)
        bob = await DHTNode.create(record_validator=validator,
                                   initial_peers=[f"{LOCALHOST}:{alice.port}"])

        store_ok = await bob.store('unknown_key', b'foo_bar',
                                   get_dht_time() + 10)
        assert store_ok == allow_extra_keys

        for peer in [alice, bob]:
            result = await peer.get('unknown_key', latest=True)
            if allow_extra_keys:
                assert result.value == b'foo_bar'
            else:
                assert result is None
Example #12
def test_rsa_signature_validator():
    receiver_validator = RSASignatureValidator()
    sender_validator = RSASignatureValidator(RSAPrivateKey())
    mallory_validator = RSASignatureValidator(RSAPrivateKey())

    plain_record = DHTRecord(key=b'key', subkey=b'subkey', value=b'value',
                             expiration_time=get_dht_time() + 10)
    protected_records = [
        dataclasses.replace(plain_record,
                            key=plain_record.key + sender_validator.local_public_key),
        dataclasses.replace(plain_record,
                            subkey=plain_record.subkey + sender_validator.local_public_key),
    ]

    # test 1: Non-protected record (no signature added)
    assert sender_validator.sign_value(plain_record) == plain_record.value
    assert receiver_validator.validate(plain_record)

    # test 2: Correct signatures
    signed_records = [dataclasses.replace(record, value=sender_validator.sign_value(record))
                      for record in protected_records]
    for record in signed_records:
        assert receiver_validator.validate(record)
        assert receiver_validator.strip_value(record) == b'value'

    # test 3: Invalid signatures
    signed_records = protected_records  # Without signature
    signed_records += [dataclasses.replace(record,
                                           value=record.value + b'[signature:INVALID_BYTES]')
                       for record in protected_records]  # With invalid signature
    signed_records += [dataclasses.replace(record, value=mallory_validator.sign_value(record))
                       for record in protected_records]  # With someone else's signature
    for record in signed_records:
        assert not receiver_validator.validate(record)
Example #13
async def test_prefix():
    class Schema(BaseModel):
        field: StrictInt

    validator = SchemaValidator(Schema, allow_extra_keys=False, prefix='prefix')

    alice = await DHTNode.create(record_validator=validator)
    bob = await DHTNode.create(
        record_validator=validator, initial_peers=[f"{LOCALHOST}:{alice.port}"])

    assert await bob.store('prefix_field', 777, get_dht_time() + 10)
    assert not await bob.store('prefix_field', 'string_value', get_dht_time() + 10)
    assert not await bob.store('field', 777, get_dht_time() + 10)

    for peer in [alice, bob]:
        assert (await peer.get('prefix_field', latest=True)).value == 777
        assert (await peer.get('field', latest=True)) is None
Example #14
    async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                             ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
        """ accept or reject a join request from another averager; if accepted, run him through allreduce steps """
        try:
            async with self.lock_request_join_group:
                reason_to_reject = self._check_reasons_to_reject(request)
                if reason_to_reject is not None:
                    yield reason_to_reject
                    return

                self.current_followers[request.endpoint] = request
                yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)

                if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
                    # outcome 1: we have assembled a full group and are ready for allreduce
                    await self.leader_assemble_group()

            # wait for the group to be assembled or disbanded
            timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
            await asyncio.wait({self.assembled_group, self.was_accepted_to_group.wait()},
                               return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
            if not self.assembled_group.done() and not self.was_accepted_to_group.is_set():
                async with self.lock_request_join_group:
                    if self.assembled_group.done():
                        pass  # this covers a rare case when the group is assembled while the event loop was busy.
                    elif len(self.current_followers) + 1 >= self.min_group_size and self.is_looking_for_group:
                        # outcome 2: the time is up, run allreduce with what we have or disband
                        await self.leader_assemble_group()
                    else:
                        await self.leader_disband_group()

            if self.was_accepted_to_group.is_set() or not self.assembled_group.done() \
                    or self.assembled_group.cancelled() or request.endpoint not in self.assembled_group.result():
                if self.current_leader is not None:
                    # outcome 3: found by a leader with higher priority, send our followers to him
                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED,
                                                          suggested_leader=self.current_leader)
                    return
                else:
                    yield averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_DISBANDED)
                    return

            allreduce_group = self.assembled_group.result()
            yield averaging_pb2.MessageFromLeader(
                code=averaging_pb2.BEGIN_ALLREDUCE, group_id=allreduce_group.group_id,
                ordered_group_endpoints=allreduce_group.ordered_group_endpoints, part_sizes=allreduce_group.part_sizes,
                gathered=allreduce_group.gathered, group_key_seed=allreduce_group.group_key_seed)
        except (concurrent.futures.CancelledError, asyncio.CancelledError):
            return  # note: this is a compatibility layer for python3.7
        except Exception as e:
            logger.exception(e)
            yield averaging_pb2.MessageFromLeader(code=averaging_pb2.INTERNAL_ERROR)

        finally:  # note: this code is guaranteed to run even if the coroutine is destroyed prematurely
            self.current_followers.pop(request.endpoint, None)
            self.follower_was_discarded.set()
Example #15
def test_signing_in_different_process():
    parent_conn, child_conn = mp.Pipe()
    process = mp.Process(target=get_signed_record, args=[child_conn])
    process.start()

    validator = RSASignatureValidator()
    parent_conn.send(validator)

    record = DHTRecord(key=b'key', subkey=b'subkey' + validator.local_public_key,
                       value=b'value', expiration_time=get_dht_time() + 10)
    parent_conn.send(record)

    signed_record = parent_conn.recv()
    assert b'[signature:' in signed_record.value
    assert validator.validate(signed_record)
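The get_signed_record target spawned by this test is not shown in the listing. A plausible sketch, assuming the helper simply signs whatever record it receives over the pipe and sends it back:

def get_signed_record(conn):
    # hypothetical worker for the test above: receive a validator and a record,
    # sign the record's value with that validator, and return the signed record
    validator = conn.recv()
    record = conn.recv()
    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
    conn.send(signed_record)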
Example #16
def test_validator_instance_is_picklable():
    # Needs to be picklable because the validator instance may be sent between processes

    original_validator = RSASignatureValidator()
    unpickled_validator = pickle.loads(pickle.dumps(original_validator))

    # To check that the private key was pickled and unpickled correctly, we sign a record
    # with the original public key using the unpickled validator and then validate the signature

    record = DHTRecord(key=b'key', subkey=b'subkey' + original_validator.local_public_key,
                       value=b'value', expiration_time=get_dht_time() + 10)
    signed_record = dataclasses.replace(record, value=unpickled_validator.sign_value(record))

    assert b'[signature:' in signed_record.value
    assert original_validator.validate(signed_record)
    assert unpickled_validator.validate(signed_record)
Example #17
async def test_merging_schema_validators(dht_nodes_with_schema):
    alice, bob = dht_nodes_with_schema

    class TrivialValidator(RecordValidatorBase):
        def validate(self, record: DHTRecord) -> bool:
            return True

    second_validator = TrivialValidator()
    # Can't merge with a validator of a different type
    assert not alice.protocol.record_validator.merge_with(second_validator)

    class SecondSchema(BaseModel):
        some_field: StrictInt
        another_field: str

    class ThirdSchema(BaseModel):
        another_field: StrictInt  # Allow it to be a StrictInt as well

    for schema in [SecondSchema, ThirdSchema]:
        new_validator = SchemaValidator(schema, allow_extra_keys=False)
        for peer in [alice, bob]:
            assert peer.protocol.record_validator.merge_with(new_validator)

    assert await bob.store('experiment_name', b'foo_bar', get_dht_time() + 10)
    assert await bob.store('some_field', 777, get_dht_time() + 10)
    assert not await bob.store('some_field', 'string_value',
                               get_dht_time() + 10)
    assert await bob.store('another_field', 42, get_dht_time() + 10)
    assert await bob.store('another_field', 'string_value',
                           get_dht_time() + 10)

    # Unknown keys are allowed since the first schema is created with allow_extra_keys=True
    assert await bob.store('unknown_key', 999, get_dht_time() + 10)

    for peer in [alice, bob]:
        assert (await peer.get('experiment_name',
                               latest=True)).value == b'foo_bar'
        assert (await peer.get('some_field', latest=True)).value == 777
        assert (await peer.get('another_field',
                               latest=True)).value == 'string_value'

        assert (await peer.get('unknown_key', latest=True)).value == 999
Example #18
async def test_expecting_dictionary(dht_nodes_with_schema):
    alice, bob = dht_nodes_with_schema

    # Dictionary (bytes -> non-negative int) expected
    assert await bob.store('n_batches', 777, get_dht_time() + 10, subkey=b'uid1')
    assert await bob.store('n_batches', 778, get_dht_time() + 10, subkey=b'uid2')
    assert not await bob.store('n_batches', -666, get_dht_time() + 10, subkey=b'uid3')
    assert not await bob.store('n_batches', 666, get_dht_time() + 10)
    assert not await bob.store('n_batches', b'not_integer', get_dht_time() + 10, subkey=b'uid1')
    assert not await bob.store('n_batches', 666, get_dht_time() + 10, subkey=666)

    # Refuse storing a plain dictionary bypassing the DictionaryDHTValue convention
    assert not await bob.store('n_batches', {b'uid3': 779}, get_dht_time() + 10)

    # Refuse records despite https://pydantic-docs.helpmanual.io/usage/models/#data-conversion
    assert not await bob.store('n_batches', 779.5, get_dht_time() + 10, subkey=b'uid3')
    assert not await bob.store('n_batches', 779.0, get_dht_time() + 10, subkey=b'uid3')
    assert not await bob.store('n_batches', [], get_dht_time() + 10)
    assert not await bob.store('n_batches', [(b'uid3', 779)], get_dht_time() + 10)

    # Refuse records despite https://github.com/samuelcolvin/pydantic/issues/1268
    assert not await bob.store('n_batches', '', get_dht_time() + 10)

    for peer in [alice, bob]:
        dictionary = (await peer.get('n_batches', latest=True)).value
        assert (len(dictionary) == 2 and
                dictionary[b'uid1'].value == 777 and
                dictionary[b'uid2'].value == 778)
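These coroutine tests are written for a pytest-asyncio style runner. Under that assumption, a minimal standalone driver for one of them, reusing the hypothetical SampleSchema and fixture logic sketched earlier, could look like:

import asyncio

async def main():
    validator = SchemaValidator(SampleSchema)
    alice = await DHTNode.create(record_validator=validator)
    bob = await DHTNode.create(record_validator=validator,
                               initial_peers=[f"{LOCALHOST}:{alice.port}"])
    # the fixture argument is just the (alice, bob) pair that the test unpacks
    await test_expecting_dictionary((alice, bob))

asyncio.run(main())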