Example #1
def setup_test_homeserver(cleanup_func,
                          name="test",
                          config=None,
                          reactor=None,
                          homeserver_to_use: Type[HomeServer] = TestHomeServer,
                          **kwargs):
    """
    Setup a homeserver suitable for running tests against.  Keyword arguments
    are passed to the Homeserver constructor.

    If no datastore is supplied, one is created and given to the homeserver.

    Args:
        cleanup_func : The function used to register a cleanup routine for
                       after the test.

    Calling this method directly is deprecated: you should instead derive from
    HomeserverTestCase.
    """
    if reactor is None:
        from twisted.internet import reactor

    if config is None:
        config = default_config(name, parse=True)

    config.ldap_enabled = False

    if "clock" not in kwargs:
        kwargs["clock"] = MockClock()

    if USE_POSTGRES_FOR_TESTS:
        test_db = "synapse_test_%s" % uuid.uuid4().hex

        database_config = {
            "name": "psycopg2",
            "args": {
                "database": test_db,
                "host": POSTGRES_HOST,
                "password": POSTGRES_PASSWORD,
                "user": POSTGRES_USER,
                "cp_min": 1,
                "cp_max": 5,
            },
        }
    else:
        database_config = {
            "name": "sqlite3",
            "args": {
                "database": ":memory:",
                "cp_min": 1,
                "cp_max": 1,
            },
        }

    database = DatabaseConnectionConfig("master", database_config)
    config.database.databases = [database]

    db_engine = create_engine(database.config)

    # Create the database before we actually try and connect to it, based off
    # the template database we generate in setupdb()
    if isinstance(db_engine, PostgresEngine):
        db_conn = db_engine.module.connect(
            database=POSTGRES_BASE_DB,
            user=POSTGRES_USER,
            host=POSTGRES_HOST,
            password=POSTGRES_PASSWORD,
        )
        db_conn.autocommit = True
        cur = db_conn.cursor()
        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db, ))
        cur.execute("CREATE DATABASE %s WITH TEMPLATE %s;" %
                    (test_db, POSTGRES_BASE_DB))
        cur.close()
        db_conn.close()

    hs = homeserver_to_use(
        name,
        config=config,
        version_string="Synapse/tests",
        reactor=reactor,
    )

    # Install @cache_in_self attributes
    for key, val in kwargs.items():
        setattr(hs, "_" + key, val)

    # Mock TLS
    hs.tls_server_context_factory = Mock()
    hs.tls_client_options_factory = Mock()

    hs.setup()
    if homeserver_to_use == TestHomeServer:
        hs.setup_background_tasks()

    if isinstance(db_engine, PostgresEngine):
        database = hs.get_datastores().databases[0]

        # We need to do cleanup on PostgreSQL
        def cleanup():
            import psycopg2

            # Close all the db pools
            database._db_pool.close()

            dropped = False

            # Drop the test database
            db_conn = db_engine.module.connect(
                database=POSTGRES_BASE_DB,
                user=POSTGRES_USER,
                host=POSTGRES_HOST,
                password=POSTGRES_PASSWORD,
            )
            db_conn.autocommit = True
            cur = db_conn.cursor()

            # Try a few times to drop the DB. Some things may hold on to the
            # database for a few more seconds due to flakiness, preventing
            # us from dropping it when the test is over. If we can't drop
            # it, warn and move on.
            for _ in range(5):
                try:
                    cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
                    db_conn.commit()
                    dropped = True
                    break
                except psycopg2.OperationalError as e:
                    warnings.warn("Couldn't drop old db: " + str(e),
                                  category=UserWarning)
                    time.sleep(0.5)

            cur.close()
            db_conn.close()

            if not dropped:
                warnings.warn("Failed to drop old DB.", category=UserWarning)

        if not LEAVE_DB:
            # Register the cleanup hook
            cleanup_func(cleanup)

    # bcrypt is far too slow to use in unit tests, so we swap the auth
    # handler's hashing out for MD5.
    # We need to let the HS build an auth handler and then patch it, because
    # AuthHandler's constructor requires the HS, so we can't make one
    # beforehand and pass it in to the HS's constructor (chicken / egg).
    async def hash(p):
        return hashlib.md5(p.encode("utf8")).hexdigest()

    hs.get_auth_handler().hash = hash

    async def validate_hash(p, h):
        return hashlib.md5(p.encode("utf8")).hexdigest() == h

    hs.get_auth_handler().validate_hash = validate_hash

    fed = kwargs.get("resource_for_federation", None)
    if fed:
        register_federation_servlets(hs, fed)

    return hs
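
Since calling setup_test_homeserver directly is deprecated, a test would
normally derive from HomeserverTestCase instead. Below is a minimal sketch of
that recommended usage, assuming Synapse's tests.unittest helpers; the test
body is illustrative and not part of the code above.

from tests import unittest

class ExampleTestCase(unittest.HomeserverTestCase):
    def make_homeserver(self, reactor, clock):
        # HomeserverTestCase calls setup_test_homeserver for us; we override
        # make_homeserver only so the test can tweak the config first.
        config = self.default_config()
        return self.setup_test_homeserver(config=config)

    def test_hostname(self):
        # setup_test_homeserver's `name` argument defaults to "test".
        self.assertEqual(self.hs.hostname, "test")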
Example #2
    async def run(self) -> None:
        """Ports the SQLite database to a PostgreSQL database.

        When a fatal error is met, its message is assigned to the global "end_error"
        variable. When this error comes with a stacktrace, its exec_info is assigned to
        the global "end_error_exec_info" variable.
        """
        global end_error

        try:
            # we allow people to port away from outdated versions of sqlite.
            self.sqlite_store = self.build_db_store(
                DatabaseConnectionConfig("master-sqlite", self.sqlite_config),
                allow_outdated_version=True,
            )

            # Check if all background updates are done, abort if not.
            updates_complete = (
                await self.sqlite_store.db_pool.updates.has_completed_background_updates()
            )
            if not updates_complete:
                end_error = (
                    "Pending background updates exist in the SQLite3 database."
                    " Please start Synapse again and wait until every update has finished"
                    " before running this script.\n")
                return

            self.postgres_store = self.build_db_store(
                self.hs_config.database.get_single_database())

            await self.run_background_updates_on_postgres()

            self.progress.set_state("Creating port tables")

            def create_port_table(txn: LoggingTransaction) -> None:
                txn.execute("CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
                            " table_name varchar(100) NOT NULL UNIQUE,"
                            " forward_rowid bigint NOT NULL,"
                            " backward_rowid bigint NOT NULL"
                            ")")

            # The old port script created a table with just a "rowid" column.
            # We want people to be able to rerun this script from an old port
            # so that they can pick up any missing events that were not
            # ported across.
            def alter_table(txn: LoggingTransaction) -> None:
                txn.execute("ALTER TABLE IF EXISTS port_from_sqlite3"
                            " RENAME rowid TO forward_rowid")
                txn.execute("ALTER TABLE IF EXISTS port_from_sqlite3"
                            " ADD backward_rowid bigint NOT NULL DEFAULT 0")

            try:
                await self.postgres_store.db_pool.runInteraction(
                    "alter_table", alter_table)
            except Exception:
                # The ALTER TABLE statements fail if the table has already
                # been migrated to the new schema; in that case it is safe
                # to carry on.
                pass

            await self.postgres_store.db_pool.runInteraction(
                "create_port_table", create_port_table)

            # Step 2. Set up sequences
            #
            # We do this before porting the tables so that even if we fail
            # halfway through, the postgres DB will always have sequences
            # that are greater than their respective tables. If we don't,
            # then creating the `DataStore` object will fail due to the
            # inconsistency.
            self.progress.set_state("Setting up sequence generators")
            await self._setup_state_group_id_seq()
            await self._setup_user_id_seq()
            await self._setup_events_stream_seqs()
            await self._setup_sequence(
                "device_inbox_sequence",
                ("device_inbox", "device_federation_outbox"))
            await self._setup_sequence(
                "account_data_sequence",
                ("room_account_data", "room_tags_revisions", "account_data"),
            )
            await self._setup_sequence("receipts_sequence",
                                       ("receipts_linearized", ))
            await self._setup_sequence("presence_stream_sequence",
                                       ("presence_stream", ))
            await self._setup_auth_chain_sequence()

            # Step 3. Get tables.
            self.progress.set_state("Fetching tables")
            sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol(
                table="sqlite_master",
                keyvalues={"type": "table"},
                retcol="name")

            postgres_tables = await self.postgres_store.db_pool.simple_select_onecol(
                table="information_schema.tables",
                keyvalues={},
                retcol="distinct table_name",
            )

            tables = set(sqlite_tables) & set(postgres_tables)
            logger.info("Found %d tables", len(tables))

            # Step 4. Figure out what still needs copying
            self.progress.set_state("Checking on port progress")
            setup_res = await make_deferred_yieldable(
                defer.gatherResults(
                    [
                        run_in_background(self.setup_table, table)
                        for table in tables
                        if table not in ["schema_version", "applied_schema_deltas"]
                        and not table.startswith("sqlite_")
                    ],
                    consumeErrors=True,
                )
            )
            # Map from table name to args passed to `handle_table`, i.e. a tuple
            # of: `postgres_size`, `table_size`, `forward_chunk`, `backward_chunk`.
            tables_to_port_info_map = {r[0]: r[1:] for r in setup_res}

            # Step 5. Do the copying.
            #
            # This is slightly convoluted as we need to ensure tables are ported
            # in the correct order due to foreign key constraints.
            self.progress.set_state("Copying to postgres")

            constraints = await self.get_table_constraints()
            tables_ported = set()  # type: Set[str]

            while tables_to_port_info_map:
                # Pulls out all tables that are still to be ported and which
                # only depend on tables that are already ported (if any).
                tables_to_port = [
                    table for table in tables_to_port_info_map
                    if not constraints.get(table, set()) - tables_ported
                ]

                await make_deferred_yieldable(
                    defer.gatherResults(
                        [
                            run_in_background(
                                self.handle_table,
                                table,
                                *tables_to_port_info_map.pop(table),
                            )
                            for table in tables_to_port
                        ],
                        consumeErrors=True,
                    )
                )

                tables_ported.update(tables_to_port)

            self.progress.done()
        except Exception as e:
            global end_error_exec_info
            end_error = str(e)
            # Type safety: we're in an exception handler, so the exc_info() tuple
            # will not be (None, None, None).
            end_error_exec_info = sys.exc_info()  # type: ignore[assignment]
            logger.exception("")
        finally:
            reactor.stop()
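
The while loop in Step 5 is effectively a topological sort: a table is only
ported once every table it references via a foreign key has been ported. The
standalone sketch below restates that idea outside the porter class; the
function and parameter names are hypothetical, not Synapse's API.

from typing import Callable, Dict, Set

def port_in_dependency_order(
    tables: Set[str],
    constraints: Dict[str, Set[str]],
    port_table: Callable[[str], None],
) -> None:
    """Port each table after the tables it depends on (hypothetical helper).

    `constraints` maps a table name to the set of tables it has foreign
    keys into, mirroring what get_table_constraints() returns above.
    """
    remaining = set(tables)
    ported: Set[str] = set()
    while remaining:
        # Tables whose foreign-key dependencies have all been ported already.
        ready = {t for t in remaining if constraints.get(t, set()) <= ported}
        if not ready:
            # Without this check, a cycle of constraints would loop forever.
            raise RuntimeError("circular foreign-key constraints: %r" % remaining)
        for table in ready:
            port_table(table)
        ported |= ready
        remaining -= ready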