Code example #1
File: appservice.py Project: samuelyi/synapse
 def get_max_as_txn_id(txn: Cursor) -> int:
     logger.warning(
         "Falling back to slow query, you should port to postgres")
     txn.execute(
         "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
     )
     return cast(Tuple[int], txn.fetchone())[0]
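A minimal, self-contained sketch of the same COALESCE(max(...), 0) fallback, runnable against the standard-library sqlite3 module. The table name demo_txns is hypothetical and only exists for the demonstration; the point is that COALESCE turns the NULL that max() returns on an empty table into 0.

import sqlite3

def get_max_txn_id(txn: sqlite3.Cursor) -> int:
    # COALESCE(max(txn_id), 0) yields 0 when the table is empty, so callers
    # always receive an int rather than None.
    txn.execute("SELECT COALESCE(max(txn_id), 0) FROM demo_txns")
    row = txn.fetchone()
    assert row is not None
    return row[0]

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE demo_txns (txn_id INTEGER)")
print(get_max_txn_id(cur))  # 0 on an empty table
cur.execute("INSERT INTO demo_txns (txn_id) VALUES (7)")
print(get_max_txn_id(cur))  # 7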
Code example #2
        def process(txn: Cursor) -> int:
            last_stream = progress.get("last_stream", -(1 << 31))
            txn.execute(
                """
                UPDATE events SET stream_ordering2=stream_ordering
                WHERE stream_ordering IN (
                   SELECT stream_ordering FROM events WHERE stream_ordering > ?
                   ORDER BY stream_ordering LIMIT ?
                )
                RETURNING stream_ordering;
                """,
                (last_stream, batch_size),
            )
            row_count = txn.rowcount
            if row_count == 0:
                return 0
            last_stream = max(row[0] for row in txn)
            logger.info("populated stream_ordering2 up to %i", last_stream)

            self.db_pool.updates._background_update_progress_txn(
                txn,
                _BackgroundUpdates.POPULATE_STREAM_ORDERING2,
                {"last_stream": last_stream},
            )
            return row_count
Code example #3
    def _get_e2e_cross_signing_signatures_for_devices_txn(
        self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
    ) -> List[Tuple[str, str, str, str]]:
        """Get cross-signing signatures for a given list of devices

        Returns signatures made by the owners of the devices.

        Returns: a list of results; each entry in the list is a tuple of
            (user_id, key_id, target_device_id, signature).
        """
        signature_query_clauses = []
        signature_query_params = []

        for (user_id, device_id) in device_query:
            signature_query_clauses.append(
                "target_user_id = ? AND target_device_id = ? AND user_id = ?")
            signature_query_params.extend([user_id, device_id, user_id])

        signature_sql = """
            SELECT user_id, key_id, target_device_id, signature
            FROM e2e_cross_signing_signatures WHERE %s
            """ % (" OR ".join("(" + q + ")" for q in signature_query_clauses))

        txn.execute(signature_sql, signature_query_params)
        return txn.fetchall()
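The clause-building loop above can be exercised end to end with an in-memory SQLite database. The sketch below is illustrative only: the table name sigs and the sample rows are made up, but the way the OR-joined clauses and the flat parameter list line up is the same.

import sqlite3
from typing import Iterable, List, Tuple

def get_signatures(txn: sqlite3.Cursor,
                   device_query: Iterable[Tuple[str, str]]) -> List[Tuple]:
    clauses = []
    params: List[str] = []
    for user_id, device_id in device_query:
        # One parenthesised clause per (user, device) pair; they are OR-ed together.
        clauses.append("target_user_id = ? AND target_device_id = ? AND user_id = ?")
        params.extend([user_id, device_id, user_id])
    # Like the original, this assumes at least one pair was supplied.
    sql = "SELECT user_id, key_id, target_device_id, signature FROM sigs WHERE %s" % (
        " OR ".join("(" + c + ")" for c in clauses))
    txn.execute(sql, params)
    return txn.fetchall()

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE sigs (user_id, key_id, target_user_id, target_device_id, signature)")
cur.execute("INSERT INTO sigs VALUES ('@a:x', 'ed25519:1', '@a:x', 'DEV1', 'sig1')")
print(get_signatures(cur, [("@a:x", "DEV1")]))  # [('@a:x', 'ed25519:1', 'DEV1', 'sig1')]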
Code example #4
File: keys.py Project: xianliangjiang/synapse
        def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
            """Processes a batch of keys to fetch, and adds the result to `keys`."""

            # batch_iter always returns tuples so it's safe to do len(batch)
            sql = ("SELECT server_name, key_id, verify_key, ts_valid_until_ms "
                   "FROM server_signature_keys WHERE 1=0"
                   ) + " OR (server_name=? AND key_id=?)" * len(batch)

            txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))

            for row in txn:
                server_name, key_id, key_bytes, ts_valid_until_ms = row

                if ts_valid_until_ms is None:
                    # Old keys may be stored with a ts_valid_until_ms of null,
                    # in which case we treat this as if it was set to `0`, i.e.
                    # it won't match key requests that define a minimum
                    # `ts_valid_until_ms`.
                    ts_valid_until_ms = 0

                keys[(server_name, key_id)] = FetchKeyResult(
                    verify_key=decode_verify_key_bytes(key_id,
                                                       bytes(key_bytes)),
                    valid_until_ts=ts_valid_until_ms,
                )
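batch_iter is a Synapse helper that the snippet assumes but does not show. A minimal sketch of how such a helper could be written (not the library's actual implementation) is below: it yields fixed-size tuples from any iterable, which is why len(batch) is safe above.

from itertools import islice
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
    """Yield successive tuples of at most `size` items from `iterable`."""
    it = iter(iterable)
    while True:
        batch = tuple(islice(it, size))
        if not batch:
            return
        yield batch

pairs = [("a.example", "ed25519:1"), ("b.example", "ed25519:2"), ("c.example", "ed25519:3")]
print(list(batch_iter(pairs, 2)))  # two batches: the first with two pairs, the last with one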
Code example #5
def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args,
               **kwargs):
    logger.info("Creating ignored_users table")
    execute_statements_from_stream(cur, StringIO(_create_commands))

    # We now upgrade existing data, if any. We don't do this in `run_upgrade` as
    # we a) want to run these before adding constraints and b) `run_upgrade` is
    # not run on empty databases.
    insert_sql = """
    INSERT INTO ignored_users (ignorer_user_id, ignored_user_id) VALUES (?, ?)
    """

    logger.info("Converting existing ignore lists")
    cur.execute(
        "SELECT user_id, content FROM account_data WHERE account_data_type = 'm.ignored_user_list'"
    )
    for user_id, content_json in cur.fetchall():
        content = db_to_json(content_json)

        # The content should be in the form of a dictionary with a key
        # "ignored_users" pointing to a dictionary whose keys are the ignored
        # users, e.g.:
        #
        # { "ignored_users": { "@someone:example.org": {} } }
        ignored_users = content.get("ignored_users", {})
        if isinstance(ignored_users, dict) and ignored_users:
            cur.execute_batch(insert_sql,
                              [(user_id, u) for u in ignored_users])

    # Add indexes after inserting data for efficiency.
    logger.info("Adding constraints to ignored_users table")
    execute_statements_from_stream(cur, StringIO(_constraints_commands))
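db_to_json is assumed above. A simplified stand-in (the real helper also copes with other input types and edge cases) plus a tiny demonstration of the expected account-data shape and the rows that would be batch-inserted:

import json
from typing import Any, List, Tuple, Union

def db_to_json(db_content: Union[bytes, str]) -> Any:
    # Simplified stand-in: decode bytes if needed, then parse the JSON.
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf-8")
    return json.loads(db_content)

content_json = '{"ignored_users": {"@someone:example.org": {}}}'
content = db_to_json(content_json)
ignored_users = content.get("ignored_users", {})
rows: List[Tuple[str, str]] = [("@me:example.org", u) for u in ignored_users]
print(rows)  # [('@me:example.org', '@someone:example.org')]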
Code example #6
        def get_rejected_events(
            txn: Cursor,
        ) -> List[Tuple[str, str, JsonDict, bool, bool]]:
            # Fetch rejected event json, their room version and whether we have
            # inserted them into the state_events or auth_events tables.
            #
            # Note we can assume that events that don't have a corresponding
            # room version are V1 rooms.
            sql = """
                SELECT DISTINCT
                    event_id,
                    COALESCE(room_version, '1'),
                    json,
                    state_events.event_id IS NOT NULL,
                    event_auth.event_id IS NOT NULL
                FROM rejections
                INNER JOIN event_json USING (event_id)
                LEFT JOIN rooms USING (room_id)
                LEFT JOIN state_events USING (event_id)
                LEFT JOIN event_auth USING (event_id)
                WHERE event_id > ?
                ORDER BY event_id
                LIMIT ?
            """

            txn.execute(
                sql,
                (
                    last_event_id,
                    batch_size,
                ),
            )

            return [(row[0], row[1], db_to_json(row[2]), row[3], row[4])
                    for row in txn]  # type: ignore
Code example #7
 def _get_event_reference_hashes_txn(self, txn: Cursor,
                                     event_id: str) -> Dict[str, bytes]:
     """Get all the hashes for a given PDU.
     Args:
         txn:
         event_id: Id for the Event.
     Returns:
         A mapping of algorithm -> hash.
     """
     query = ("SELECT algorithm, hash"
              " FROM event_reference_hashes"
              " WHERE event_id = ?")
     txn.execute(query, (event_id, ))
     return {k: v for k, v in txn}
Code example #8
def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args,
               **kwargs):
    # some instances might already have this index, in which case we can skip this
    if isinstance(database_engine, PostgresEngine):
        cur.execute("""
            SELECT 1 FROM pg_class WHERE relkind = 'i'
            AND relname = 'device_lists_outbound_last_success_unique_idx'
            """)

        if cur.rowcount:
            logger.info(
                "Unique index exists on device_lists_outbound_last_success: "
                "skipping rebuild")
            return

    logger.info(
        "Rebuilding device_lists_outbound_last_success with unique index")
    execute_statements_from_stream(cur, StringIO(_rebuild_commands))
Code example #9
    def _update_stream_positions_table_txn(self, txn: Cursor):
        """Update the `stream_positions` table with newly persisted position."""

        if not self._writers:
            return

        # We upsert the value, ensuring on conflict that we always increase the
        # value (or decrease if stream goes backwards).
        sql = """
            INSERT INTO stream_positions (stream_name, instance_name, stream_id)
            VALUES (?, ?, ?)
            ON CONFLICT (stream_name, instance_name)
            DO UPDATE SET
                stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
        """ % {
            "agg": "GREATEST" if self._positive else "LEAST",
        }

        pos = self.get_current_token_for_writer(self._instance_name)
        txn.execute(sql, (self._stream_name, self._instance_name, pos))
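The GREATEST/LEAST upsert has a direct SQLite analogue, since SQLite's two-argument MAX()/MIN() act as scalar functions. The self-contained sketch below (hypothetical table, SQLite 3.24+ for UPSERT support) shows the same "only ever move the position forward" behaviour:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("""
    CREATE TABLE stream_positions (
        stream_name TEXT, instance_name TEXT, stream_id INTEGER,
        UNIQUE (stream_name, instance_name)
    )
""")

upsert = """
    INSERT INTO stream_positions (stream_name, instance_name, stream_id)
    VALUES (?, ?, ?)
    ON CONFLICT (stream_name, instance_name)
    DO UPDATE SET stream_id = MAX(stream_positions.stream_id, EXCLUDED.stream_id)
"""

cur.execute(upsert, ("events", "worker1", 10))
cur.execute(upsert, ("events", "worker1", 7))   # lower value: position stays at 10
cur.execute(upsert, ("events", "worker1", 12))  # higher value: position moves to 12
print(cur.execute("SELECT stream_id FROM stream_positions").fetchone())  # (12,)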
Code example #10
def _get_or_create_schema_state(
    txn: Cursor, database_engine: BaseDatabaseEngine
) -> Optional[Tuple[int, List[str], bool]]:
    # Bluntly try creating the schema_version tables.
    sql_path = os.path.join(schema_path, "common", "schema_version.sql")
    executescript(txn, sql_path)

    txn.execute("SELECT version, upgraded FROM schema_version")
    row = txn.fetchone()

    if row is not None:
        current_version = int(row[0])
        txn.execute(
            "SELECT file FROM applied_schema_deltas WHERE version >= ?",
            (current_version, ),
        )
        applied_deltas = [d for d, in txn]
        upgraded = bool(row[1])
        return current_version, applied_deltas, upgraded

    return None
Code example #11
File: prepare_database.py Project: samuel-p/synapse
def _get_or_create_schema_state(
        txn: Cursor,
        database_engine: BaseDatabaseEngine) -> Optional[_SchemaState]:
    # Bluntly try creating the schema_version tables.
    sql_path = os.path.join(schema_path, "common", "schema_version.sql")
    executescript(txn, sql_path)

    txn.execute("SELECT version, upgraded FROM schema_version")
    row = txn.fetchone()

    if row is None:
        # new database
        return None

    current_version = int(row[0])
    upgraded = bool(row[1])

    compat_version: Optional[int] = None
    txn.execute("SELECT compat_version FROM schema_compat_version")
    row = txn.fetchone()
    if row is not None:
        compat_version = int(row[0])

    txn.execute(
        "SELECT file FROM applied_schema_deltas WHERE version >= ?",
        (current_version, ),
    )
    applied_deltas = tuple(d for d, in txn)

    return _SchemaState(
        current_version=current_version,
        compat_version=compat_version,
        applied_deltas=applied_deltas,
        upgraded=upgraded,
    )
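_SchemaState is referenced but not defined in this excerpt. A hypothetical container with the four fields the function actually populates might look like this (the real class may differ in detail):

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class _SchemaState:
    current_version: int                  # version recorded in schema_version
    compat_version: Optional[int]         # value from schema_compat_version, if present
    applied_deltas: Tuple[str, ...] = ()  # delta files already applied for current_version
    upgraded: bool = False                # True if this version was reached by applying deltas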
Code example #12
def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args,
               **kwargs):
    if not isinstance(database_engine, PostgresEngine):
        # this only applies to postgres - sqlite does not distinguish between big and
        # little ints.
        return

    # First add a new column to contain the bigger min_depth
    cur.execute("ALTER TABLE room_depth ADD COLUMN min_depth2 BIGINT")

    # Create a trigger which will keep it populated.
    cur.execute("""
        CREATE OR REPLACE FUNCTION populate_min_depth2() RETURNS trigger AS $BODY$
            BEGIN
                new.min_depth2 := new.min_depth;
                RETURN NEW;
            END;
        $BODY$ LANGUAGE plpgsql
        """)

    cur.execute("""
        CREATE TRIGGER populate_min_depth2_trigger BEFORE INSERT OR UPDATE ON room_depth
        FOR EACH ROW
        EXECUTE PROCEDURE populate_min_depth2()
        """)

    # Start a bg process to populate it for old rooms
    cur.execute("""
       INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
            (6103, 'populate_room_depth_min_depth2', '{}')
       """)

    # and another to switch them over once it completes.
    cur.execute("""
        INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
            (6103, 'replace_room_depth_min_depth', '{}', 'populate_room_depth2')
        """)
Code example #13
def check_database_before_upgrade(cur: Cursor,
                                  database_engine: BaseDatabaseEngine,
                                  config: HomeServerConfig) -> None:
    """Called before upgrading an existing database to check that it is broadly sane
    compared with the configuration.
    """
    logger.info("Checking database for consistency with configuration...")

    # if there are any users in the database, check that the username matches our
    # configured server name.

    cur.execute("SELECT name FROM users LIMIT 1")
    rows = cur.fetchall()
    if not rows:
        return

    user_domain = get_domain_from_id(rows[0][0])
    if user_domain == config.server.server_name:
        return

    raise Exception(
        "Found users in database not native to %s!\n"
        "You cannot change a synapse server_name after it's been configured" %
        (config.server.server_name, ))
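get_domain_from_id is assumed by the check above. A simplified stand-in that splits a Matrix ID such as @alice:example.org at its first colon (the real helper raises a Synapse-specific error type) is sketched below:

def get_domain_from_id(string: str) -> str:
    # Matrix IDs look like "@localpart:domain"; everything after the first
    # colon is the server name.
    idx = string.find(":")
    if idx == -1:
        raise ValueError("Invalid ID: %r" % (string,))
    return string[idx + 1:]

print(get_domain_from_id("@alice:example.org"))  # example.org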
Code example #14
def _apply_module_schema_files(
    cur: Cursor,
    database_engine: BaseDatabaseEngine,
    modname: str,
    names_and_streams: Iterable[Tuple[str, TextIO]],
) -> None:
    """Apply the module schemas for a single module

    Args:
        cur: database cursor
        database_engine: synapse database engine class
        modname: fully qualified name of the module
        names_and_streams: the names and streams of schemas to be applied
    """
    cur.execute(
        "SELECT file FROM applied_module_schemas WHERE module_name = ?",
        (modname, ),
    )
    applied_deltas = {d for d, in cur}
    for (name, stream) in names_and_streams:
        if name in applied_deltas:
            continue

        root_name, ext = os.path.splitext(name)
        if ext != ".sql":
            raise PrepareDatabaseException(
                "only .sql files are currently supported for module schemas")

        logger.info("applying schema %s for %s", name, modname)
        execute_statements_from_stream(cur, stream)

        # Mark as done.
        cur.execute(
            "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)",
            (modname, name),
        )
Code example #15
def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
    # complain if the room_id in partial_state_events doesn't match
    # that in `events`. We already have a fk constraint which ensures that the event
    # exists in `events`, so all we have to do is raise if there is a row with a
    # matching event_id but not a matching room_id.
    if isinstance(database_engine, Sqlite3Engine):
        cur.execute(
            """
            CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id
            BEFORE INSERT ON partial_state_events
            FOR EACH ROW
            BEGIN
                SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events')
                WHERE EXISTS (
                    SELECT 1 FROM events
                    WHERE events.event_id = NEW.event_id
                       AND events.room_id != NEW.room_id
                );
            END;
            """
        )
    elif isinstance(database_engine, PostgresEngine):
        cur.execute(
            """
            CREATE OR REPLACE FUNCTION check_partial_state_events() RETURNS trigger AS $BODY$
            BEGIN
                IF EXISTS (
                    SELECT 1 FROM events
                    WHERE events.event_id = NEW.event_id
                       AND events.room_id != NEW.room_id
                ) THEN
                    RAISE EXCEPTION 'Incorrect room_id in partial_state_events';
                END IF;
                RETURN NEW;
            END;
            $BODY$ LANGUAGE plpgsql;
            """
        )

        cur.execute(
            """
            CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events
            FOR EACH ROW
            EXECUTE PROCEDURE check_partial_state_events()
            """
        )
    else:
        raise NotImplementedError("Unknown database engine")
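The SQLite branch can be exercised on its own. The sketch below creates minimal stand-in tables (far smaller than the real schema), installs the same trigger, and shows the insert being rejected when the room_id disagrees:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE events (event_id TEXT PRIMARY KEY, room_id TEXT)")
cur.execute("CREATE TABLE partial_state_events (event_id TEXT, room_id TEXT)")
cur.execute("""
    CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id
    BEFORE INSERT ON partial_state_events
    FOR EACH ROW
    BEGIN
        SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events')
        WHERE EXISTS (
            SELECT 1 FROM events
            WHERE events.event_id = NEW.event_id
               AND events.room_id != NEW.room_id
        );
    END;
""")

cur.execute("INSERT INTO events VALUES ('$ev1', '!room:a')")
cur.execute("INSERT INTO partial_state_events VALUES ('$ev1', '!room:a')")  # accepted
try:
    cur.execute("INSERT INTO partial_state_events VALUES ('$ev1', '!other:a')")
except sqlite3.DatabaseError as e:  # typically an IntegrityError
    print("rejected:", e)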
Code example #16
File: store.py Project: vishnumg/synapse
 def get_max_state_group_txn(txn: Cursor):
     txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
     return txn.fetchone()[0]
Code example #17
File: postgres.py Project: matrix-org/synapse
 def lock_table(self, txn: Cursor, table: str) -> None:
     txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table, ))
Code example #18
def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
    for statement in get_statements(f):
        cur.execute(statement)
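get_statements is another helper this snippet assumes. A deliberately naive stand-in that splits the stream on semicolons (and therefore cannot cope with semicolons inside string literals or trigger bodies, which the real parser has to handle) might look like this:

from typing import Iterator, TextIO

def get_statements(f: TextIO) -> Iterator[str]:
    # Naive splitter: read everything and break on ";", skipping empty chunks.
    for statement in f.read().split(";"):
        statement = statement.strip()
        if statement:
            yield statement

# Usage with the function above (StringIO comes from the io module):
#   execute_statements_from_stream(
#       cur, StringIO("CREATE TABLE t (x INT); INSERT INTO t VALUES (1);"))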
Code example #19
def _upgrade_existing_database(
    cur: Cursor,
    current_version: int,
    applied_delta_files: List[str],
    upgraded: bool,
    database_engine: BaseDatabaseEngine,
    config: Optional[HomeServerConfig],
    databases: Collection[str],
    is_empty: bool = False,
) -> None:
    """Upgrades an existing physical database.

    Delta files can either be SQL stored in *.sql files, or python modules
    in *.py.

    There can be multiple delta files per version. Synapse will keep track of
    which delta files have been applied, and will apply any that haven't been
    even if there has been no version bump. This is useful for development
    where orthogonal schema changes may happen on separate branches.

    Different delta files for the same version *must* be orthogonal and give
    the same result when applied in any order. No guarantees are made on the
    order of execution of these scripts.

    This is a no-op if current_version == SCHEMA_VERSION.

    Example directory structure:

        schema/
            delta/
                11/
                    foo.sql
                    ...
                12/
                    foo.sql
                    bar.py
                ...
            full_schemas/
                ...

    In the example, if current_version is 11, then 11/foo.sql will be run if
    and only if `upgraded` is True. Then 12/foo.sql and 12/bar.py would be run
    in some arbitrary order.

    Note: we apply the delta files from the specified data stores as well as
    those in the top-level schema. We apply all delta files across data stores
    for a version before applying those in the next version.

    Args:
        cur
        current_version: The current version of the schema.
        applied_delta_files: A list of deltas that have already been applied.
        upgraded: Whether the current version was generated by having
            applied deltas or from full schema file. If `True` the function
            will never apply delta files for the given `current_version`, since
            the current_version wasn't generated by applying those delta files.
        database_engine
        config:
            None if we are initialising a blank database, otherwise the application
            config
        databases: The names of the databases to instantiate
            on the given physical database.
        is_empty: Is this a blank database? I.e. do we need to run the
            upgrade portions of the delta scripts.
    """
    if is_empty:
        assert not applied_delta_files
    else:
        assert config

    is_worker = config and config.worker_app is not None

    if current_version > SCHEMA_VERSION:
        raise ValueError("Cannot use this database as it is too " +
                         "new for the server to understand")

    # some of the deltas assume that config.server_name is set correctly, so now
    # is a good time to run the sanity check.
    if not is_empty and "main" in databases:
        from synapse.storage.databases.main import check_database_before_upgrade

        assert config is not None
        check_database_before_upgrade(cur, database_engine, config)

    start_ver = current_version

    # if we got to this schema version by running a full_schema rather than a series
    # of deltas, we should not run the deltas for this version.
    if not upgraded:
        start_ver += 1

    logger.debug("applied_delta_files: %s", applied_delta_files)

    if isinstance(database_engine, PostgresEngine):
        specific_engine_extension = ".postgres"
    else:
        specific_engine_extension = ".sqlite"

    specific_engine_extensions = (".sqlite", ".postgres")

    for v in range(start_ver, SCHEMA_VERSION + 1):
        if not is_worker:
            logger.info("Applying schema deltas for v%d", v)

            cur.execute("DELETE FROM schema_version")
            cur.execute(
                "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
                (v, True),
            )
        else:
            logger.info("Checking schema deltas for v%d", v)

        # We need to search both the global and per data store schema
        # directories for schema updates.

        # First we find the directories to search in
        delta_dir = os.path.join(schema_path, "common", "delta", str(v))
        directories = [delta_dir]
        for database in databases:
            directories.append(
                os.path.join(schema_path, database, "delta", str(v)))

        # Used to check if we have any duplicate file names
        file_name_counter = Counter()  # type: CounterType[str]

        # Now find which directories have anything of interest.
        directory_entries = []  # type: List[_DirectoryListing]
        for directory in directories:
            logger.debug("Looking for schema deltas in %s", directory)
            try:
                file_names = os.listdir(directory)
                directory_entries.extend(
                    _DirectoryListing(file_name,
                                      os.path.join(directory, file_name))
                    for file_name in file_names)

                for file_name in file_names:
                    file_name_counter[file_name] += 1
            except FileNotFoundError:
                # Data stores can have empty entries for a given version delta.
                pass
            except OSError:
                raise UpgradeDatabaseException(
                    "Could not open delta dir for version %d: %s" %
                    (v, directory))

        duplicates = {
            file_name
            for file_name, count in file_name_counter.items() if count > 1
        }
        if duplicates:
            # We don't support using the same file name in the same delta version.
            raise PrepareDatabaseException(
                "Found multiple delta files with the same name in v%d: %s" % (
                    v,
                    duplicates,
                ))

        # We sort to ensure that we apply the delta files in a consistent
        # order (to avoid bugs caused by inconsistent directory listing order)
        directory_entries.sort()
        for entry in directory_entries:
            file_name = entry.file_name
            relative_path = os.path.join(str(v), file_name)
            absolute_path = entry.absolute_path

            logger.debug("Found file: %s (%s)", relative_path, absolute_path)
            if relative_path in applied_delta_files:
                continue

            root_name, ext = os.path.splitext(file_name)

            if ext == ".py":
                # This is a python upgrade module. We need to import into some
                # package and then execute its `run_upgrade` function.
                if is_worker:
                    raise PrepareDatabaseException(
                        UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path)

                module_name = "synapse.storage.v%d_%s" % (v, root_name)

                spec = importlib.util.spec_from_file_location(
                    module_name, absolute_path)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)  # type: ignore

                logger.info("Running script %s", relative_path)
                module.run_create(cur, database_engine)  # type: ignore
                if not is_empty:
                    module.run_upgrade(cur, database_engine,
                                       config=config)  # type: ignore
            elif ext == ".pyc" or file_name == "__pycache__":
                # Sometimes .pyc files turn up anyway even though we've
                # disabled their generation; e.g. from distribution package
                # installers. Silently skip it
                continue
            elif ext == ".sql":
                # A plain old .sql file, just read and execute it
                if is_worker:
                    raise PrepareDatabaseException(
                        UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path)
                logger.info("Applying schema %s", relative_path)
                executescript(cur, absolute_path)
            elif ext == specific_engine_extension and root_name.endswith(
                    ".sql"):
                # A .sql file specific to our engine; just read and execute it
                if is_worker:
                    raise PrepareDatabaseException(
                        UNAPPLIED_DELTA_ON_WORKER_ERROR % relative_path)
                logger.info("Applying engine-specific schema %s",
                            relative_path)
                executescript(cur, absolute_path)
            elif ext in specific_engine_extensions and root_name.endswith(
                    ".sql"):
                # A .sql file for a different engine; skip it.
                continue
            else:
                # Not a valid delta file.
                logger.warning(
                    "Found directory entry that did not end in .py or .sql: %s",
                    relative_path,
                )
                continue

            # Mark as done.
            cur.execute(
                "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)",
                (v, relative_path),
            )

    logger.info("Schema now up to date")
Code example #20
def _setup_new_database(cur: Cursor, database_engine: BaseDatabaseEngine,
                        databases: Collection[str]) -> None:
    """Sets up the physical database by finding a base set of "full schemas" and
    then applying any necessary deltas, including schemas from the given data
    stores.

    The "full_schemas" directory has subdirectories named after versions. This
    function searches for the highest version less than or equal to
    `SCHEMA_VERSION` and executes all .sql files in that directory.

    The function will then apply all deltas for all versions after the base
    version.

    Example directory structure:

    schema/
        common/
            delta/
                ...
            full_schemas/
                11/
                    foo.sql
        main/
            delta/
                ...
            full_schemas/
                3/
                    test.sql
                    ...
                11/
                    bar.sql
                ...

    In the example foo.sql and bar.sql would be run, and then any delta files
    for versions strictly greater than 11.

    Note: we apply the full schemas and deltas from the `schema/common`
    folder as well as those in the databases specified.

    Args:
        cur: a database cursor
        database_engine
        databases: The names of the databases to instantiate on the given physical database.
    """

    # We're about to set up a brand new database so we check that it's
    # configured to our liking.
    database_engine.check_new_database(cur)

    full_schemas_dir = os.path.join(schema_path, "common", "full_schemas")

    # First we find the highest full schema version we have
    valid_versions = []

    for filename in os.listdir(full_schemas_dir):
        try:
            ver = int(filename)
        except ValueError:
            continue

        if ver <= SCHEMA_VERSION:
            valid_versions.append(ver)

    if not valid_versions:
        raise PrepareDatabaseException(
            "Could not find a suitable base set of full schemas")

    max_current_ver = max(valid_versions)

    logger.debug("Initialising schema v%d", max_current_ver)

    # Now let's find all the full schema files, both in the common schema and
    # in database schemas.
    directories = [os.path.join(full_schemas_dir, str(max_current_ver))]
    directories.extend(
        os.path.join(
            schema_path,
            database,
            "full_schemas",
            str(max_current_ver),
        ) for database in databases)

    directory_entries = []  # type: List[_DirectoryListing]
    for directory in directories:
        directory_entries.extend(
            _DirectoryListing(file_name, os.path.join(directory, file_name))
            for file_name in os.listdir(directory))

    if isinstance(database_engine, PostgresEngine):
        specific = "postgres"
    else:
        specific = "sqlite"

    directory_entries.sort()
    for entry in directory_entries:
        if entry.file_name.endswith(".sql") or entry.file_name.endswith(
                ".sql." + specific):
            logger.debug("Applying schema %s", entry.absolute_path)
            executescript(cur, entry.absolute_path)

    cur.execute(
        "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
        (max_current_ver, False),
    )

    _upgrade_existing_database(
        cur,
        current_version=max_current_ver,
        applied_delta_files=[],
        upgraded=False,
        database_engine=database_engine,
        config=None,
        databases=databases,
        is_empty=True,
    )
Code example #21
 def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
     txn.execute("SELECT nextval(?) FROM generate_series(1, ?)",
                 (self._sequence_name, n))
     return [i for (i, ) in txn]
Code example #22
File: sequence.py Project: yvwvnacb/synapse
 def get_next_id_txn(self, txn: Cursor) -> int:
     txn.execute("SELECT nextval(?)", (self._sequence_name, ))
     return txn.fetchone()[0]
Code example #23
File: postgres.py Project: matrix-org/synapse
 def get_db_locale(self, txn: Cursor) -> Tuple[str, str]:
     txn.execute(
         "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
     )
     collation, ctype = cast(Tuple[str, str], txn.fetchone())
     return collation, ctype
Code example #24
    def _get_auth_chain_ids_using_cover_index_txn(
            self, txn: Cursor, room_id: str, event_ids: Collection[str],
            include_given: bool) -> List[str]:
        """Calculates the auth chain IDs using the chain index."""

        # First we look up the chain ID/sequence numbers for the given events.

        initial_events = set(event_ids)

        # All the events that we've found that are reachable from the events.
        seen_events = set()  # type: Set[str]

        # A map from chain ID to max sequence number of the given events.
        event_chains = {}  # type: Dict[int, int]

        sql = """
            SELECT event_id, chain_id, sequence_number
            FROM event_auth_chains
            WHERE %s
        """
        for batch in batch_iter(initial_events, 1000):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "event_id", batch)
            txn.execute(sql % (clause, ), args)

            for event_id, chain_id, sequence_number in txn:
                seen_events.add(event_id)
                event_chains[chain_id] = max(sequence_number,
                                             event_chains.get(chain_id, 0))

        # Check that we actually have a chain ID for all the events.
        events_missing_chain_info = initial_events.difference(seen_events)
        if events_missing_chain_info:
            # This can happen due to e.g. downgrade/upgrade of the server. We
            # raise an exception and fall back to the previous algorithm.
            logger.info(
                "Unexpectedly found that events don't have chain IDs in room %s: %s",
                room_id,
                events_missing_chain_info,
            )
            raise _NoChainCoverIndex(room_id)

        # Now we look up all links for the chains we have, adding chains that
        # are reachable from any event.
        sql = """
            SELECT
                origin_chain_id, origin_sequence_number,
                target_chain_id, target_sequence_number
            FROM event_auth_chain_links
            WHERE %s
        """

        # A map from chain ID to max sequence number *reachable* from any event ID.
        chains = {}  # type: Dict[int, int]

        # Add all linked chains reachable from initial set of chains.
        for batch in batch_iter(event_chains, 1000):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "origin_chain_id", batch)
            txn.execute(sql % (clause, ), args)

            for (
                    origin_chain_id,
                    origin_sequence_number,
                    target_chain_id,
                    target_sequence_number,
            ) in txn:
                # chains are only reachable if the origin sequence number of
                # the link is no greater than the max sequence number in the
                # origin chain.
                if origin_sequence_number <= event_chains.get(
                        origin_chain_id, 0):
                    chains[target_chain_id] = max(
                        target_sequence_number,
                        chains.get(target_chain_id, 0),
                    )

        # Add the initial set of chains, excluding the sequence corresponding to
        # initial event.
        for chain_id, seq_no in event_chains.items():
            chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))

        # Now for each chain we figure out the maximum sequence number reachable
        # from *any* event ID. Events with a sequence less than that are in the
        # auth chain.
        if include_given:
            results = initial_events
        else:
            results = set()

        if isinstance(self.database_engine, PostgresEngine):
            # We can use `execute_values` to efficiently fetch the gaps when
            # using postgres.
            sql = """
                SELECT event_id
                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
                WHERE
                    c.chain_id = l.chain_id
                    AND sequence_number <= max_seq
            """

            rows = txn.execute_values(sql, chains.items())
            results.update(r for r, in rows)
        else:
            # For SQLite we just fall back to doing a noddy for loop.
            sql = """
                SELECT event_id FROM event_auth_chains
                WHERE chain_id = ? AND sequence_number <= ?
            """
            for chain_id, max_no in chains.items():
                txn.execute(sql, (chain_id, max_no))
                results.update(r for r, in txn)

        return list(results)
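make_in_list_sql_clause is assumed in both chain-index methods. A simplified stand-in that always emits an IN (...) clause with one placeholder per value is shown below; the real helper also takes the database engine and can emit engine-specific forms instead:

from typing import Any, Iterable, List, Tuple

def make_in_list_sql_clause(column: str, iterable: Iterable[Any]) -> Tuple[str, List[Any]]:
    values = list(iterable)
    # e.g. ("event_id IN (?, ?, ?)", ["$a", "$b", "$c"])
    return "%s IN (%s)" % (column, ", ".join("?" for _ in values)), values

clause, args = make_in_list_sql_clause("event_id", ["$a", "$b", "$c"])
print(clause, args)  # event_id IN (?, ?, ?) ['$a', '$b', '$c']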
Code example #25
    def _calculate_chain_cover_txn(
        self,
        txn: Cursor,
        last_room_id: str,
        last_depth: int,
        last_stream: int,
        batch_size: Optional[int],
        single_room: bool,
    ) -> _CalculateChainCover:
        """Calculate the chain cover for `batch_size` events, ordered by
        `(room_id, depth, stream)`.

        Args:
            txn:
            last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
                tuple to fetch results after.
            batch_size: The maximum number of events to process. If None then
                no limit.
            single_room: Whether to calculate the index for just the given
                room.
        """

        # Get the next set of events in the room (that we haven't already
        # computed chain cover for). We do this in topological order.

        # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
        # comparison, but that is not supported on older SQLite versions
        tuple_clause, tuple_args = make_tuple_comparison_clause([
            ("events.room_id", last_room_id),
            ("topological_ordering", last_depth),
            ("stream_ordering", last_stream),
        ], )

        extra_clause = ""
        if single_room:
            extra_clause = "AND events.room_id = ?"
            tuple_args.append(last_room_id)

        sql = """
            SELECT
                event_id, state_events.type, state_events.state_key,
                topological_ordering, stream_ordering,
                events.room_id
            FROM events
            INNER JOIN state_events USING (event_id)
            LEFT JOIN event_auth_chains USING (event_id)
            LEFT JOIN event_auth_chain_to_calculate USING (event_id)
            WHERE event_auth_chains.event_id IS NULL
                AND event_auth_chain_to_calculate.event_id IS NULL
                AND %(tuple_cmp)s
                %(extra)s
            ORDER BY events.room_id, topological_ordering, stream_ordering
            %(limit)s
        """ % {
            "tuple_cmp": tuple_clause,
            "limit": "LIMIT ?" if batch_size is not None else "",
            "extra": extra_clause,
        }

        if batch_size is not None:
            tuple_args.append(batch_size)

        txn.execute(sql, tuple_args)
        rows = txn.fetchall()

        # Put the results in the necessary format for
        # `_add_chain_cover_index`
        event_to_room_id = {row[0]: row[5] for row in rows}
        event_to_types = {row[0]: (row[1], row[2]) for row in rows}

        # Calculate the new last position we've processed up to.
        new_last_depth: int = rows[-1][3] if rows else last_depth
        new_last_stream: int = rows[-1][4] if rows else last_stream
        new_last_room_id: str = rows[-1][5] if rows else ""

        # Map from room_id to last depth/stream_ordering processed for the room,
        # excluding the last room (which we're likely still processing). We also
        # need to include the room passed in if it's not included in the result
        # set (as we then know we've processed all events in said room).
        #
        # This is the set of rooms that we can now safely flip the
        # `has_auth_chain_index` bit for.
        finished_rooms = {
            row[5]: (row[3], row[4])
            for row in rows if row[5] != new_last_room_id
        }
        if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
            finished_rooms[last_room_id] = (last_depth, last_stream)

        count = len(rows)

        # We also need to fetch the auth events for them.
        auth_events = self.db_pool.simple_select_many_txn(
            txn,
            table="event_auth",
            column="event_id",
            iterable=event_to_room_id,
            keyvalues={},
            retcols=("event_id", "auth_id"),
        )

        event_to_auth_chain: Dict[str, List[str]] = {}
        for row in auth_events:
            event_to_auth_chain.setdefault(row["event_id"],
                                           []).append(row["auth_id"])

        # Calculate and persist the chain cover index for this set of events.
        #
        # Annoyingly we need to gut wrench into the persist event store so that
        # we can reuse the function to calculate the chain cover for rooms.
        PersistEventsStore._add_chain_cover_index(
            txn,
            self.db_pool,
            self.event_chain_id_gen,
            event_to_room_id,
            event_to_types,
            event_to_auth_chain,
        )

        return _CalculateChainCover(
            room_id=new_last_room_id,
            depth=new_last_depth,
            stream=new_last_stream,
            processed_count=count,
            finished_room_map=finished_rooms,
        )
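make_tuple_comparison_clause is assumed above. A simplified stand-in that emits a native row-value comparison (which current SQLite and postgres both support, sidestepping the nested-OR expansion needed for very old SQLite) could look like this:

from typing import Any, List, Tuple

def make_tuple_comparison_clause(keys: List[Tuple[str, Any]]) -> Tuple[str, List[Any]]:
    # Builds "(a, b, c) > (?, ?, ?)" plus its argument list, for keyset pagination.
    columns = ", ".join(name for name, _ in keys)
    placeholders = ", ".join("?" for _ in keys)
    return "(%s) > (%s)" % (columns, placeholders), [value for _, value in keys]

clause, args = make_tuple_comparison_clause([
    ("events.room_id", "!room:a"),
    ("topological_ordering", 10),
    ("stream_ordering", 99),
])
print(clause)  # (events.room_id, topological_ordering, stream_ordering) > (?, ?, ?)
print(args)    # ['!room:a', 10, 99]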
Code example #26
 def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
     txn.execute("""
         SELECT update_name, depends_on FROM background_updates
         ORDER BY ordering, update_name
         """)
     return self.db_pool.cursor_to_dict(txn)
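cursor_to_dict is a DatabasePool helper assumed here. A simplified, standalone stand-in built on the DB-API cursor.description attribute is shown below:

from typing import Any, Dict, List

def cursor_to_dict(cursor: Any) -> List[Dict[str, Any]]:
    # cursor.description is a sequence of column descriptors whose first
    # element is the column name (DB-API 2.0).
    col_headers = [desc[0] for desc in cursor.description]
    return [dict(zip(col_headers, row)) for row in cursor]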
Code example #27
 def process(txn: Cursor) -> None:
     for sql in _REPLACE_STREAM_ORDERING_SQL_COMMANDS:
         logger.info("completing stream_ordering migration: %s", sql)
         txn.execute(sql)
Code example #28
 def get_next_id_txn(self, txn: Cursor) -> int:
     txn.execute("SELECT nextval(?)", (self._sequence_name, ))
     fetch_res = txn.fetchone()
     assert fetch_res is not None
     return fetch_res[0]
Code example #29
    def _get_auth_chain_difference_using_cover_index_txn(
            self, txn: Cursor, room_id: str,
            state_sets: List[Set[str]]) -> Set[str]:
        """Calculates the auth chain difference using the chain index.

        See docs/auth_chain_difference_algorithm.md for details
        """

        # First we look up the chain ID/sequence numbers for all the events, and
        # work out the chain/sequence numbers reachable from each state set.

        initial_events = set(state_sets[0]).union(*state_sets[1:])

        # Map from event_id -> (chain ID, seq no)
        chain_info = {}  # type: Dict[str, Tuple[int, int]]

        # Map from chain ID -> seq no -> event Id
        chain_to_event = {}  # type: Dict[int, Dict[int, str]]

        # All the chains that we've found that are reachable from the state
        # sets.
        seen_chains = set()  # type: Set[int]

        sql = """
            SELECT event_id, chain_id, sequence_number
            FROM event_auth_chains
            WHERE %s
        """
        for batch in batch_iter(initial_events, 1000):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "event_id", batch)
            txn.execute(sql % (clause, ), args)

            for event_id, chain_id, sequence_number in txn:
                chain_info[event_id] = (chain_id, sequence_number)
                seen_chains.add(chain_id)
                chain_to_event.setdefault(chain_id,
                                          {})[sequence_number] = event_id

        # Check that we actually have a chain ID for all the events.
        events_missing_chain_info = initial_events.difference(chain_info)
        if events_missing_chain_info:
            # This can happen due to e.g. downgrade/upgrade of the server. We
            # raise an exception and fall back to the previous algorithm.
            logger.info(
                "Unexpectedly found that events don't have chain IDs in room %s: %s",
                room_id,
                events_missing_chain_info,
            )
            raise _NoChainCoverIndex(room_id)

        # Corresponds to `state_sets`, except as a map from chain ID to max
        # sequence number reachable from the state set.
        set_to_chain = []  # type: List[Dict[int, int]]
        for state_set in state_sets:
            chains = {}  # type: Dict[int, int]
            set_to_chain.append(chains)

            for event_id in state_set:
                chain_id, seq_no = chain_info[event_id]

                chains[chain_id] = max(seq_no, chains.get(chain_id, 0))

        # Now we look up all links for the chains we have, adding chains to
        # set_to_chain that are reachable from each set.
        sql = """
            SELECT
                origin_chain_id, origin_sequence_number,
                target_chain_id, target_sequence_number
            FROM event_auth_chain_links
            WHERE %s
        """

        # (We need to take a copy of `seen_chains` as we want to mutate it in
        # the loop)
        for batch in batch_iter(set(seen_chains), 1000):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "origin_chain_id", batch)
            txn.execute(sql % (clause, ), args)

            for (
                    origin_chain_id,
                    origin_sequence_number,
                    target_chain_id,
                    target_sequence_number,
            ) in txn:
                for chains in set_to_chain:
                    # chains are only reachable if the origin sequence number of
                    # the link is no greater than the max sequence number in the
                    # origin chain.
                    if origin_sequence_number <= chains.get(
                            origin_chain_id, 0):
                        chains[target_chain_id] = max(
                            target_sequence_number,
                            chains.get(target_chain_id, 0),
                        )

                seen_chains.add(target_chain_id)

        # Now for each chain we figure out the maximum sequence number reachable
        # from *any* state set and the minimum sequence number reachable from
        # *all* state sets. Events in that range are in the auth chain
        # difference.
        result = set()

        # Mapping from chain ID to the range of sequence numbers that should be
        # pulled from the database.
        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]

        for chain_id in seen_chains:
            min_seq_no = min(
                chains.get(chain_id, 0) for chains in set_to_chain)
            max_seq_no = max(
                chains.get(chain_id, 0) for chains in set_to_chain)

            if min_seq_no < max_seq_no:
                # We have a non empty gap, try and fill it from the events that
                # we have, otherwise add them to the list of gaps to pull out
                # from the DB.
                for seq_no in range(min_seq_no + 1, max_seq_no + 1):
                    event_id = chain_to_event.get(chain_id, {}).get(seq_no)
                    if event_id:
                        result.add(event_id)
                    else:
                        chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
                        break

        if not chain_to_gap:
            # If there are no gaps to fetch, we're done!
            return result

        if isinstance(self.database_engine, PostgresEngine):
            # We can use `execute_values` to efficiently fetch the gaps when
            # using postgres.
            sql = """
                SELECT event_id
                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
                WHERE
                    c.chain_id = l.chain_id
                    AND min_seq < sequence_number AND sequence_number <= max_seq
            """

            args = [(chain_id, min_no, max_no)
                    for chain_id, (min_no, max_no) in chain_to_gap.items()]

            rows = txn.execute_values(sql, args)
            result.update(r for r, in rows)
        else:
            # For SQLite we just fall back to doing a noddy for loop.
            sql = """
                SELECT event_id FROM event_auth_chains
                WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
            """
            for chain_id, (min_no, max_no) in chain_to_gap.items():
                txn.execute(sql, (chain_id, min_no, max_no))
                result.update(r for r, in txn)

        return result