Ejemplo n.º 1
0
    def __init__(self,
                 hs: "HomeServer",
                 key_fetchers: "Optional[Iterable[KeyFetcher]]" = None):
        """Set up the keyring.

        Args:
            hs: the homeserver; its clock, config, hostname and signing key
                are read here.
            key_fetchers: the fetchers to consult for keys. When omitted, a
                StoreKeyFetcher, a PerspectivesKeyFetcher and a
                ServerKeyFetcher are used, in that order.
        """
        self.clock = hs.get_clock()

        # Fall back to the standard chain of fetchers when none are supplied.
        self._key_fetchers = key_fetchers if key_fetchers is not None else (
            StoreKeyFetcher(hs),
            PerspectivesKeyFetcher(hs),
            ServerKeyFetcher(hs),
        )

        # Key requests are batched up and handed to _inner_fetch_key_requests.
        self._server_queue: BatchingQueue[
            _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
        ] = BatchingQueue(
            "keyring_server",
            clock=hs.get_clock(),
            process_batch_callback=self._inner_fetch_key_requests,
        )

        self._hostname = hs.hostname

        # Pre-build results for our own keys (old and current) so that lookups
        # for them can short-circuit the fetchers.
        self._local_verify_keys: Dict[str, FetchKeyResult] = {
            old_key_id: FetchKeyResult(
                verify_key=old_key, valid_until_ts=old_key.expired_ts)
            for old_key_id, old_key in hs.config.key.old_signing_keys.items()
        }

        current_vk = get_verify_key(hs.signing_key)
        current_key_id = f"{current_vk.alg}:{current_vk.version}"
        # Our current key gets a fake far-future expiry timestamp.
        self._local_verify_keys[current_key_id] = FetchKeyResult(
            verify_key=current_vk,
            valid_until_ts=2**63,
        )
Ejemplo n.º 2
0
    def test_get_server_verify_keys(self):
        """Stored keys can be read back; unknown key ids map to None."""
        store = self.hs.get_datastore()

        key_id_1 = "ed25519:key1"
        key_id_2 = "ed25519:KEY_ID_2"
        stored = store.store_server_verify_keys(
            "from_server",
            10,
            [
                ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
                ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
            ],
        )
        self.get_success(stored)

        lookups = [
            ("server1", key_id_1),
            ("server1", key_id_2),
            ("server1", "ed25519:key3"),
        ]
        found = self.get_success(store.get_server_verify_keys(lookups))

        # One entry per requested pair, including the missing one.
        self.assertEqual(len(found.keys()), 3)

        first = found[("server1", key_id_1)]
        self.assertEqual(first.verify_key, KEY_1)
        self.assertEqual(first.verify_key.version, "key1")
        self.assertEqual(first.valid_until_ts, 100)

        second = found[("server1", key_id_2)]
        self.assertEqual(second.verify_key, KEY_2)
        # The version is taken from the key id used when storing the key.
        self.assertEqual(second.verify_key.version, "KEY_ID_2")
        self.assertEqual(second.valid_until_ts, 200)

        # A key that was never stored comes back as None.
        self.assertIsNone(found[("server1", "ed25519:key3")])
Ejemplo n.º 3
0
    def test_cache(self):
        """Storing a new version of a key must invalidate the cached copy."""
        store = self.hs.get_datastore()

        key_id_1 = "ed25519:key1"
        key_id_2 = "ed25519:key2"

        def check_both_keys(expected_key_2, expected_ts_2):
            # Look up both keys. key1 must be unchanged; key2 must match the
            # expected key and validity.
            both = self.get_success(
                store.get_server_verify_keys([("srv1", key_id_1),
                                              ("srv1", key_id_2)])
            )
            self.assertEqual(len(both.keys()), 2)

            first = both[("srv1", key_id_1)]
            self.assertEqual(first.verify_key, KEY_1)
            self.assertEqual(first.valid_until_ts, 100)

            second = both[("srv1", key_id_2)]
            self.assertEqual(second.verify_key, expected_key_2)
            self.assertEqual(second.valid_until_ts, expected_ts_2)

        self.get_success(
            store.store_server_verify_keys(
                "from_server",
                0,
                [
                    ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
                    ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
                ],
            )
        )
        check_both_keys(KEY_2, 200)

        # Repeating a lookup should not need a db hit, so the result may come
        # back directly rather than as a Deferred.
        cached = store.get_server_verify_keys([("srv1", key_id_1)])
        if isinstance(cached, Deferred):
            cached = self.successResultOf(cached)
        self.assertEqual(len(cached.keys()), 1)
        self.assertEqual(cached[("srv1", key_id_1)].verify_key, KEY_1)

        # Replace key2 with a freshly generated key and a later expiry.
        new_key_2 = signedjson.key.get_verify_key(
            signedjson.key.generate_signing_key("key2"))
        self.get_success(
            store.store_server_verify_keys(
                "from_server", 10,
                [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))],
            )
        )
        check_both_keys(new_key_2, 300)
Ejemplo n.º 4
0
 async def get_keys2(keys_to_fetch):
     """Fake fetcher: expects one request for key1 on server1 (min 1500)."""
     expected_request = {"server1": {get_key_id(key1): 1500}}
     self.assertEqual(keys_to_fetch, expected_request)

     # Answer with key1, but only valid until 1200.
     result_id = get_key_id(key1)
     return {"server1": {result_id: FetchKeyResult(get_verify_key(key1), 1200)}}
Ejemplo n.º 5
0
 async def second_lookup_fetch(
         server_name: str, key_ids: List[str],
         minimum_valid_until_ts: int) -> Dict[str, FetchKeyResult]:
     """Fake fetcher for the second lookup: returns key1, valid until 100.

     The arguments are accepted but not checked.
     """
     # Disabled context check, kept for reference:
     # self.assertEquals(current_context().request.id, "context_12")
     result_id = get_key_id(key1)
     return {result_id: FetchKeyResult(get_verify_key(key1), 100)}
Ejemplo n.º 6
0
 async def second_lookup_fetch(keys_to_fetch):
     """Fake fetcher for the second lookup.

     Checks that current_context().request is "context_12", then returns key1
     for server10 with a validity of 100.
     """
     # assertEquals is a deprecated alias of assertEqual (stdlib unittest
     # removed it in Python 3.12), so use the canonical name.
     self.assertEqual(current_context().request, "context_12")
     return {
         "server10": {
             get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
         }
     }
Ejemplo n.º 7
0
        def _get_keys(txn, batch):
            """Look up one batch of (server_name, key_id) pairs and record each
            row found in the enclosing `keys` dict.
            """
            # batch_iter always returns tuples so it's safe to do len(batch).
            # The query starts from an always-false clause so that one
            # "OR (...)" per pair can simply be appended.
            sql = (
                "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
                "FROM server_signature_keys WHERE 1=0"
                + " OR (server_name=? AND key_id=?)" * len(batch)
            )

            # Flatten the pairs into positional query parameters.
            params = tuple(itertools.chain.from_iterable(batch))
            txn.execute(sql, params)

            for server_name, key_id, key_bytes, ts_valid_until_ms in txn:
                # Old keys may be stored with a null ts_valid_until_ms. Treat
                # that as 0, so they won't match key requests that define a
                # minimum ts_valid_until_ms.
                valid_until = 0 if ts_valid_until_ms is None else ts_valid_until_ms
                keys[(server_name, key_id)] = FetchKeyResult(
                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
                    valid_until_ts=valid_until,
                )
Ejemplo n.º 8
0
 def get_keys1(keys_to_fetch):
     """Fake fetcher: expects key1 on server1 (min 1500) and answers via an
     already-fired Deferred with key1 valid only until 800.
     """
     self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
     fetched = {
         "server1": {
             get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800),
         },
     }
     return defer.succeed(fetched)
Ejemplo n.º 9
0
        async def get_keys(
                server_name: str, key_ids: List[str],
                minimum_valid_until_ts: int) -> Dict[str, FetchKeyResult]:
            """Fake fetcher: expects a lookup of key2 on our own server and
            returns it with a validity of 1200.
            """
            self.assertEqual(server_name, self.hs.hostname)
            self.assertEqual(key_ids, [get_key_id(key2)])

            result_id = get_key_id(key2)
            return {result_id: FetchKeyResult(get_verify_key(key2), 1200)}
Ejemplo n.º 10
0
 async def get_keys2(
         server_name: str, key_ids: List[str],
         minimum_valid_until_ts: int) -> Dict[str, FetchKeyResult]:
     """Fake fetcher: expects key1 on server1 with minimum validity 1500 and
     returns key1 valid until 1200.
     """
     self.assertEqual(server_name, "server1")
     self.assertEqual(key_ids, [get_key_id(key1)])
     self.assertEqual(minimum_valid_until_ts, 1500)

     result_id = get_key_id(key1)
     return {result_id: FetchKeyResult(get_verify_key(key1), 1200)}
Ejemplo n.º 11
0
        async def get_keys(keys_to_fetch):
            """Fake fetcher which expects exactly one request, for key1 on
            server1, and returns key1 valid until 1200.
            """
            # Only one request object should arrive, carrying the max validity.
            expected_request = {"server1": {get_key_id(key1): 1500}}
            self.assertEqual(keys_to_fetch, expected_request)

            result_id = get_key_id(key1)
            return {
                "server1": {result_id: FetchKeyResult(get_verify_key(key1), 1200)}
            }
Ejemplo n.º 12
0
        async def first_lookup_fetch(keys_to_fetch):
            """Fake fetcher for the first lookup.

            Checks that current_context().request is "context_11" and that only
            key1 on server10 (minimum validity 0) was requested. It then waits
            on first_lookup_deferred before returning key1 valid until 100.
            """
            # assertEquals is a deprecated alias of assertEqual (stdlib unittest
            # removed it in Python 3.12), so use the canonical name.
            self.assertEqual(current_context().request, "context_11")
            self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})

            # Don't return until the test fires first_lookup_deferred.
            await make_deferred_yieldable(first_lookup_deferred)
            return {
                "server10": {
                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
                }
            }
Ejemplo n.º 13
0
        async def get_keys(
                server_name: str, key_ids: List[str],
                minimum_valid_until_ts: int) -> Dict[str, FetchKeyResult]:
            """Fake fetcher which expects exactly one request, for key1 on
            server1 with minimum validity 1500, and returns key1 valid until
            1200.
            """
            # Only one request object should arrive, carrying the max validity.
            self.assertEqual(server_name, "server1")
            self.assertEqual(key_ids, [get_key_id(key1)])
            self.assertEqual(minimum_valid_until_ts, 1500)

            result_id = get_key_id(key1)
            return {result_id: FetchKeyResult(get_verify_key(key1), 1200)}
Ejemplo n.º 14
0
        async def first_lookup_fetch(
                server_name: str, key_ids: List[str],
                minimum_valid_until_ts: int) -> Dict[str, FetchKeyResult]:
            """Fake fetcher for the first lookup.

            Checks it was asked for key1 on server10 with no minimum validity.
            It then waits on first_lookup_deferred before returning key1 valid
            until 100.
            """
            # Disabled context check, kept for reference:
            # self.assertEquals(current_context().request.id, "context_11")
            self.assertEqual(server_name, "server10")
            self.assertEqual(key_ids, [get_key_id(key1)])
            self.assertEqual(minimum_valid_until_ts, 0)

            # Don't return until the test fires first_lookup_deferred.
            await make_deferred_yieldable(first_lookup_deferred)

            result_id = get_key_id(key1)
            return {result_id: FetchKeyResult(get_verify_key(key1), 100)}
Ejemplo n.º 15
0
    def test_verify_json_for_server_with_null_valid_until_ms(self):
        """Keys stored with a null `ts_valid_until_ms` satisfy a zero minimum
        validity, but cause a (failing) refetch for a non-zero minimum.
        """
        fallback_fetcher = keyring.KeyFetcher()
        # The fallback fetcher never finds anything.
        fallback_fetcher.get_keys = Mock(return_value=defer.succeed({}))

        kr = keyring.Keyring(
            self.hs, key_fetchers=(StoreKeyFetcher(self.hs), fallback_fetcher)
        )

        key1 = signedjson.key.generate_signing_key(1)
        stored = self.hs.datastore.store_server_verify_keys(
            "server9",
            time.time() * 1000,
            [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
        )
        self.get_success(stored)

        signed_json = {}
        signedjson.sign.sign_json(signed_json, "server9", key1)

        # An unsigned object is rejected straight away.
        unsigned_result = _verify_json_for_server(
            kr, "server9", {}, 0, "test unsigned"
        )
        self.failureResultOf(unsigned_result, SynapseError)

        # With a non-zero minimum_valid_until_ms the stored key is not good
        # enough, so the keyring refetches and fails.
        nonzero_result = _verify_json_for_server(
            kr, "server9", signed_json, 500, "test signed non-zero min"
        )
        self.get_failure(nonzero_result, SynapseError)

        # The refetch should have happened exactly once.
        fallback_fetcher.get_keys.assert_called_once_with(
            {"server9": {get_key_id(key1): 500}}
        )

        # A zero minimum_valid_until_ms is satisfied by the stored key.
        zero_result = _verify_json_for_server(
            kr, "server9", signed_json, 0, "test signed with zero min"
        )
        self.get_success(zero_result)
Ejemplo n.º 16
0
    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
        super().prepare(reactor, clock, hs)

        # Seed the key store with the other server's signing key so that the
        # tests never need to request it.
        other_verify_key = signedjson.key.get_verify_key(
            self.OTHER_SERVER_SIGNATURE_KEY)
        other_key_id = f"{other_verify_key.alg}:{other_verify_key.version}"

        store = hs.get_datastores().main
        self.get_success(
            store.store_server_verify_keys(
                from_server=self.OTHER_SERVER_NAME,
                ts_added_ms=clock.time_msec(),
                verify_keys=[
                    (
                        self.OTHER_SERVER_NAME,
                        other_key_id,
                        FetchKeyResult(
                            verify_key=other_verify_key,
                            # valid for 1000ms from now
                            valid_until_ts=clock.time_msec() + 1000,
                        ),
                    )
                ],
            )
        )
Ejemplo n.º 17
0
    def test_verify_json_for_server(self):
        """The keyring verifies JSON signed with a key already in the store."""
        kr = keyring.Keyring(self.hs)

        key1 = signedjson.key.generate_signing_key(1)
        stored = self.hs.datastore.store_server_verify_keys(
            "server9",
            time.time() * 1000,
            [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
        )
        self.get_success(stored)

        signed_json = {}
        signedjson.sign.sign_json(signed_json, "server9", key1)

        # An unsigned object must be rejected immediately.
        unsigned_result = _verify_json_for_server(
            kr, "server9", {}, 0, "test unsigned"
        )
        self.failureResultOf(unsigned_result, SynapseError)

        # A signed object should succeed: the stored key's validity (1000)
        # covers the requested minimum (500).
        signed_result = _verify_json_for_server(
            kr, "server9", signed_json, 500, "test signed"
        )
        self.get_success(signed_result)
Ejemplo n.º 18
0
    async def process_v2_response(
            self, from_server: str, response_json: JsonDict,
            time_added_ms: int) -> Dict[str, FetchKeyResult]:
        """Parse a 'Server Keys' structure from the result of a /key request

        This handles either a whole GET /_matrix/key/v2/server response, or one
        entry of the list returned by POST /_matrix/key/v2/query.

        The response must carry at least one signature from the origin server
        made with a key listed in its `verify_keys`. The first such signature
        found is verified.

        The json is saved in server_keys_json (once per returned key id), so
        that it can be used for future responses to /_matrix/key/v2/query.

        Args:
            from_server: who produced this result: the origin server for a
                /_matrix/key/v2/server request, or the notary for a
                /_matrix/key/v2/query.

            response_json: the json-decoded Server Keys response object

            time_added_ms: the timestamp to record in server_keys_json

        Returns:
            Map from key_id to result object

        Raises:
            KeyLookupError: if no origin-server signature can be checked
                against a key in `verify_keys`.
        """

        def _decode(key_id, key_data):
            # The key material is base64-encoded under "key".
            return decode_verify_key_bytes(key_id, decode_base64(key_data["key"]))

        valid_until = response_json["valid_until_ts"]

        # Current keys are parsed first, because one of them may be needed to
        # check the signature on the response itself.
        verify_keys = {
            key_id: FetchKeyResult(
                verify_key=_decode(key_id, key_data), valid_until_ts=valid_until)
            for key_id, key_data in response_json["verify_keys"].items()
            if is_signing_algorithm_supported(key_id)
        }

        server_name = response_json["server_name"]
        origin_signatures = response_json["signatures"].get(server_name, {})
        # Signatures by keys missing from verify_keys are passed over. That can
        # legitimately happen when the notary supplied the keys, the keys are
        # the notary's own, and the notary signs notary responses with a
        # different key.
        signing_key = next(
            (key for key in map(verify_keys.get, origin_signatures) if key),
            None,
        )
        if signing_key is None:
            raise KeyLookupError(
                "Key response for %s is not signed by the origin server" %
                (server_name, ))
        verify_signed_json(response_json, server_name, signing_key.verify_key)

        # Old keys carry their own expiry timestamp.
        for key_id, key_data in response_json["old_verify_keys"].items():
            if not is_signing_algorithm_supported(key_id):
                continue
            verify_keys[key_id] = FetchKeyResult(
                verify_key=_decode(key_id, key_data),
                valid_until_ts=key_data["expired_ts"])

        key_json_bytes = encode_canonical_json(response_json)

        # Store the json once per key id, concurrently.
        store_calls = [
            run_in_background(
                self.store.store_server_keys_json,
                server_name=server_name,
                key_id=key_id,
                from_server=from_server,
                ts_now_ms=time_added_ms,
                ts_expires_ms=valid_until,
                key_json_bytes=key_json_bytes,
            )
            for key_id in verify_keys
        ]
        await make_deferred_yieldable(
            defer.gatherResults(store_calls, consumeErrors=True)
            .addErrback(unwrapFirstError))

        return verify_keys
Ejemplo n.º 19
0
    def process_v2_response(self, from_server, response_json, time_added_ms):
        """Parse a 'Server Keys' structure from the result of a /key request

        This is used to parse either the entirety of the response from
        GET /_matrix/key/v2/server, or a single entry from the list returned by
        POST /_matrix/key/v2/query.

        Checks that each signature in the response that claims to come from the
        origin server, and whose key is present in the response, is valid. At
        least one such signature is required.

        Stores the json in server_keys_json so that it can be used for future responses
        to /_matrix/key/v2/query.

        Args:
            from_server (str): the name of the server producing this result: either
                the origin server for a /_matrix/key/v2/server request, or the notary
                for a /_matrix/key/v2/query.

            response_json (dict): the json-decoded Server Keys response object

            time_added_ms (int): the timestamp to record in server_keys_json

        Returns:
            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object

        Raises:
            KeyLookupError: if none of the origin server's signatures can be
                checked against a key in the response. Any error from
                verify_signed_json on a checkable signature also propagates.
        """
        ts_valid_until_ms = response_json["valid_until_ts"]

        # start by extracting the keys from the response, since they may be required
        # to validate the signature on the response.
        verify_keys = {}
        for key_id, key_data in response_json["verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                key_base64 = key_data["key"]
                key_bytes = decode_base64(key_base64)
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
                verify_keys[key_id] = FetchKeyResult(
                    verify_key=verify_key, valid_until_ts=ts_valid_until_ms)

        server_name = response_json["server_name"]
        verified = False
        for key_id in response_json["signatures"].get(server_name, {}):
            key = verify_keys.get(key_id)
            if not key:
                # The key may legitimately be absent from verify_keys when:
                #  * we got the key from the notary server, and
                #  * the key belongs to the notary server, and
                #  * the notary uses a different key to sign notary responses.
                # Rejecting the whole response here would refuse such valid
                # notary responses. Skip the signature instead; at least one
                # checkable, valid signature is still required below.
                continue

            verify_signed_json(response_json, server_name, key.verify_key)
            verified = True

        if not verified:
            raise KeyLookupError(
                "Key response for %s is not signed by the origin server" %
                (server_name, ))

        for key_id, key_data in response_json["old_verify_keys"].items():
            if is_signing_algorithm_supported(key_id):
                key_base64 = key_data["key"]
                key_bytes = decode_base64(key_base64)
                verify_key = decode_verify_key_bytes(key_id, key_bytes)
                verify_keys[key_id] = FetchKeyResult(
                    verify_key=verify_key,
                    valid_until_ts=key_data["expired_ts"])

        # re-sign the json with our own key, so that it is ready if we are asked to
        # give it out as a notary server
        signed_key_json = sign_json(response_json, self.config.server_name,
                                    self.config.signing_key[0])

        signed_key_json_bytes = encode_canonical_json(signed_key_json)

        yield make_deferred_yieldable(
            defer.gatherResults(
                [
                    run_in_background(
                        self.store.store_server_keys_json,
                        server_name=server_name,
                        key_id=key_id,
                        from_server=from_server,
                        ts_now_ms=time_added_ms,
                        ts_expires_ms=ts_valid_until_ms,
                        key_json_bytes=signed_key_json_bytes,
                    ) for key_id in verify_keys
                ],
                consumeErrors=True,
            ).addErrback(unwrapFirstError))

        defer.returnValue(verify_keys)