Exemple #1
0
    def bible(
        self,
        request: pytest.FixtureRequest,
        default_version: str,
        default_abbr: str,
        MockBible: type[Any],
    ) -> Any:
        """Build a MockBible whose abbreviation/version come from the test's data fixture.

        ``test_search`` pulls ``search_data`` and ``test_get_passage`` pulls
        ``passage_data``; any other test gets the defaults.
        """
        test_function = cast(Callable[..., Any], cast(Any, request).function)
        data_fixture_by_test = {
            'test_search': 'search_data',
            'test_get_passage': 'passage_data',
        }
        data_fixture = data_fixture_by_test.get(test_function.__name__)

        data: dict[str, Any] = (
            request.getfixturevalue(data_fixture) if data_fixture is not None else {}
        )

        return MockBible(
            command='bib',
            name='The Bible',
            abbr=data.get('abbr', default_abbr),
            service='MyService',
            service_version=data.get('version', default_version),
            rtl=False,
        )
Exemple #2
0
    def testStorageAndHydraulicExports(self, project: _Project, request: _pt.FixtureRequest):
        """Export hydraulics (mfs, then ddck) and storage tanks; compare every output file."""
        helper = _Helper(project)
        helper.setup()

        # The following line is required otherwise QT will crash
        application = _qtw.QApplication([])

        # A closure (rather than passing ``application.quit`` directly) keeps the
        # QApplication object referenced until the finalizer runs.
        def _quit_application():
            application.quit()

        request.addfinalizer(_quit_application)

        project_folder = helper.actualProjectFolderPath

        self._exportHydraulic(project_folder, _format="mfs")
        helper.ensureFilesAreEqual(
            f"{project.projectName}_mfs.dck", shallReplaceRandomizedFlowRates=True
        )

        self._exportHydraulic(project_folder, _format="ddck")
        helper.ensureFilesAreEqual(
            "ddck/hydraulic/hydraulic.ddck", shallReplaceRandomizedFlowRates=False
        )

        for tank_name in self._exportStorageTanksAndGetNames(project_folder):
            for extension in ("ddck", "ddcx"):
                helper.ensureFilesAreEqual(
                    f"ddck/{tank_name}/{tank_name}.{extension}",
                    shallReplaceRandomizedFlowRates=False,
                )
Exemple #3
0
def test_change_container_pass(happi_client: Client, item: str, target: str,
                               request: pytest.FixtureRequest):
    """Every edit returned by ``change_container`` must show up in ``item.post()``."""
    source = request.getfixturevalue(item)
    destination = request.getfixturevalue(target)
    edits = happi_client.change_container(source, destination)

    for key, expected in edits.items():
        assert source.post()[key] == expected
Exemple #4
0
def bokeh_server(request: pytest.FixtureRequest, log_file: IO[str]) -> str:
    """Start ``bokeh serve`` on the configured port and return its base URL.

    Output goes to ``log_file``. The server is killed (and reaped) by a
    finalizer. The session exits with status 1 in three cases: the process
    cannot be spawned, it does not answer within the wait timeout, or it dies
    during startup.
    """
    bokeh_port: int = request.config.option.bokeh_port

    cmd = ["python", "-m", "bokeh", "serve"]
    argv = [f"--port={bokeh_port}"]
    bokeh_server_url = f"http://localhost:{bokeh_port}"

    env = os.environ.copy()
    env['BOKEH_MINIFIED'] = 'false'

    try:
        proc = subprocess.Popen(cmd + argv,
                                env=env,
                                stdout=log_file,
                                stderr=log_file)
    except OSError:
        write(f"Failed to run: {' '.join(cmd + argv)}")
        sys.exit(1)
    else:
        # Add in the clean-up code
        def stop_bokeh_server() -> None:
            write("Shutting down bokeh-server ...")
            proc.kill()
            # Reap the child so it does not linger as a zombie process.
            proc.wait()

        request.addfinalizer(stop_bokeh_server)

        def wait_until(func: Callable[[], Any],
                       timeout: float = 5.0,
                       interval: float = 0.01) -> bool:
            start = time.time()

            while True:
                if func():
                    return True
                if time.time() - start > timeout:
                    return False
                time.sleep(interval)

        def wait_for_bokeh_server() -> bool:
            def helper() -> Any:
                # ``Popen.returncode`` is only refreshed by poll()/wait(); reading
                # the attribute alone would never notice that the server died.
                if proc.poll() is not None:
                    return True
                try:
                    return requests.get(bokeh_server_url)
                except ConnectionError:
                    return False

            return wait_until(helper)

        if not wait_for_bokeh_server():
            write(f"Timeout when running: {' '.join(cmd + argv)}")
            sys.exit(1)

        if proc.poll() is not None:
            write(f"bokeh server exited with code {proc.returncode}")
            sys.exit(1)

        return bokeh_server_url
Exemple #5
0
def test_setup_test_db_creates_db(request: FixtureRequest, db_dsn):
    """After ``setup_test_db``, ``SELECT File`` on the test database returns nothing.

    The test database is dropped in a ``finally`` so a failing assertion does
    not leave it behind for the next test.
    """
    server_dsn, dsn, db_name = db_dsn

    request.getfixturevalue("setup_test_db")

    try:
        with create_db_client(dsn) as db_client:
            assert len(db_client.query("SELECT File")) == 0
    finally:
        with create_db_client(server_dsn) as db_client:
            db_client.execute(f"DROP DATABASE {db_name};")
Exemple #6
0
def projects_path(user_path: pathlib.Path,
                  request: pytest.FixtureRequest) -> pathlib.Path:
    """User's local checkouts and clones. Ephemeral directory.

    Creates ``user_path / "projects"`` (reusing it if it already exists) and
    registers a finalizer that removes it, contents included, at teardown.
    """
    path = user_path / "projects"
    path.mkdir(exist_ok=True)

    def clean() -> None:
        shutil.rmtree(path)

    request.addfinalizer(clean)
    return path
Exemple #7
0
def remote_repos_path(user_path: pathlib.Path,
                      request: pytest.FixtureRequest) -> pathlib.Path:
    """System's remote (file-based) repos to clone and push to. Ephemeral directory.

    Creates ``user_path / "remote_repos"`` (reusing it if it already exists) and
    registers a finalizer that removes it, contents included, at teardown.
    """
    path = user_path / "remote_repos"
    path.mkdir(exist_ok=True)

    def clean() -> None:
        shutil.rmtree(path)

    request.addfinalizer(clean)
    return path
Exemple #8
0
def test_recreates_db_and_applies_migration(request: FixtureRequest, db_dsn):
    """With the test database pre-created, ``setup_test_db`` still leaves ``SELECT File`` empty.

    The database is dropped in a ``finally`` so a failure does not leave it
    behind (a leftover database would make the next ``CREATE DATABASE`` fail).
    """
    server_dsn, dsn, db_name = db_dsn
    with create_db_client(server_dsn) as db_client:
        db_client.execute(f"CREATE DATABASE {db_name};")

    try:
        request.getfixturevalue("setup_test_db")

        with create_db_client(dsn) as db_client:
            assert len(db_client.query("SELECT File")) == 0
    finally:
        with create_db_client(server_dsn) as db_client:
            db_client.execute(f"DROP DATABASE {db_name};")
def screenshoter(browser: WebDriver, request: pytest.FixtureRequest):
    """Return an object whose ``save()`` stores a browser screenshot.

    ``save()`` is also registered as a finalizer, so a screenshot is taken at
    teardown. Nothing is written when ``SCREENSHOTS_FOLDER`` is None.
    """
    class Screenshoter:
        def save(self):
            if SCREENSHOTS_FOLDER is not None:
                screenshot_path = (
                    f'{SCREENSHOTS_FOLDER}/{time.time_ns()}_{request.node.name}.png'
                )
                assert browser.save_screenshot(screenshot_path)

    shooter = Screenshoter()
    request.addfinalizer(shooter.save)
    return shooter
Exemple #10
0
def db_client_or_tx(request: FixtureRequest):
    """
    Yield the ``db_client`` fixture when the test's ``database`` marker has a
    truthy ``transaction`` kwarg, otherwise the ``tx`` fixture.

    Raises RuntimeError if the test has no ``database`` marker at all.
    """
    marker = request.node.get_closest_marker("database")
    if not marker:
        raise RuntimeError("Access to database without `database` marker!")

    wants_transaction = marker.kwargs.get("transaction", False)
    fixture_name = "db_client" if wants_transaction else "tx"
    yield request.getfixturevalue(fixture_name)
Exemple #11
0
def test_file_path_and_url(request: pytest.FixtureRequest,
                           file_server: SimpleWebServer) -> Tuple[str, str]:
    """Return ``(local path, served URL)`` for a per-test HTML file.

    The file sits in the requesting test module's directory and is named after
    the test function. A finalizer deletes it at teardown if it exists.
    """
    html_file = request.fspath.dirpath().join(f"{request.function.__name__}.html")
    local_path = html_file.strpath

    def tear_down() -> None:
        if html_file.isfile():
            html_file.remove()

    request.addfinalizer(tear_down)

    # Forward slashes keep the URL valid for Windows paths.
    return local_path, file_server.where_is(local_path.replace('\\', '/'))
Exemple #12
0
def _config(request: FixtureRequest):
    """Fixture that parametrizes the configuration used in the tests below.

    ``request.param`` is the nevergrad model name. The fixture:

    - skips models listed in ``NOT_WORKING``;
    - sets ``model_name``/``num_workers`` on ``TestNevergradOptimizer.config``;
    - applies the per-test mark from ``MODEL_NAMES`` (``"skip"`` skips,
      anything else truthy is added as a marker);
    - temporarily raises ``max_trials`` for MultiScaleCMA/test_state_dict,
      restoring it after the test.
    """
    test_name = request.function.__name__

    model_name: str = request.param  # type: ignore
    model_type = ng.optimizers.registry[model_name]

    if model_name in NOT_WORKING:
        pytest.skip(reason=f"Model {model_name} is not supported.")

    tweaks = MODEL_NAMES[model_name]

    num_workers = 1 if model_type.no_parallelization else 10

    TestNevergradOptimizer.config["model_name"] = model_name
    TestNevergradOptimizer.config["num_workers"] = num_workers

    mark = tweaks.get(test_name, None)
    current_phase: TestPhase = request.getfixturevalue("phase")
    is_multiscale_state_dict = (
        model_name == "MultiScaleCMA" and test_name == "test_state_dict"
    )

    if (mark and test_name == "test_seed_rng_init"
            and current_phase.n_trials > 0
            and mark in _deterministic_first_point.values()):
        # Remove the mark, because the algo always gives back the same first trial, regardless of
        # the seed. This means that since `test_seed_rng_init` expects different seeds to give
        # different results, the test will fail if we're at the first phase, but pass in other
        # phases.
        mark = None

    if is_multiscale_state_dict and current_phase.n_trials == 0:
        # NOTE: Only fails at the optimization phase.
        mark = None

    if mark == "skip":
        pytest.skip(reason="Skipping test")
    elif mark:
        request.node.add_marker(mark)

    start = TestNevergradOptimizer.max_trials
    if is_multiscale_state_dict:
        TestNevergradOptimizer.max_trials = 20

    yield

    # pytest resumes the generator after the test (pass or fail), restoring the value.
    TestNevergradOptimizer.max_trials = start
Exemple #13
0
def test_recreates_db(request: FixtureRequest, db_dsn):
    """With the test database pre-created, ``setup_test_db`` leaves it without a ``File`` type.

    The database is dropped in a ``finally`` so a failure does not leave it
    behind for later tests.
    """
    server_dsn, dsn, db_name = db_dsn
    with create_db_client(server_dsn) as db_client:
        db_client.execute(f"CREATE DATABASE {db_name};")

    try:
        request.getfixturevalue("setup_test_db")

        with create_db_client(dsn) as db_client:
            # The query itself must raise; no assertion on its result is needed.
            with pytest.raises(edgedb.InvalidReferenceError) as excinfo:
                db_client.query("SELECT File")

        assert str(excinfo.value) == "object type or alias 'default::File' does not exist"
    finally:
        with create_db_client(server_dsn) as db_client:
            db_client.execute(f"DROP DATABASE {db_name};")
def possibly_skip_test(request: pytest.FixtureRequest,
                       info: dict[str, Any]) -> dict[str, Any]:
    """Skip or xfail the current test according to ``info``; return ``info`` unchanged.

    - A ``"segfault"`` key skips the test.
    - An ``"xfail"`` key either calls ``pytest.xfail`` immediately or, when the
      ``run_xfail`` option is set, applies an ``xfail(run=False)`` marker.
    """
    if "segfault" in info:
        pytest.skip(f"known segfault: {info['segfault']}")

    if "xfail" in info:
        message = f"known failure: {info['xfail']}"
        if not request.config.option.run_xfail:
            pytest.xfail(message)
        else:
            request.applymarker(pytest.mark.xfail(run=False, reason=message))
    return info
    def postgresql_factory(request: FixtureRequest) -> Iterator[connection]:
        """
        Fixture factory for PostgreSQL.

        Connects (via psycopg) to the database described by the process
        fixture, loads each entry of ``load`` through the janitor, yields the
        connection, and always closes it afterwards.

        :param request: fixture request object
        :returns: postgresql client
        """
        check_for_psycopg()
        proc_fixture: Union[PostgreSQLExecutor,
                            NoopExecutor] = request.getfixturevalue(
                                process_fixture_name)

        pg_host = proc_fixture.host
        pg_port = proc_fixture.port
        pg_user = proc_fixture.user
        pg_password = proc_fixture.password
        pg_options = proc_fixture.options
        pg_db = dbname or proc_fixture.dbname
        pg_load = load or []

        with DatabaseJanitor(pg_user, pg_host, pg_port, pg_db,
                             proc_fixture.version, pg_password,
                             isolation_level) as janitor:
            db_connection: connection = psycopg.connect(
                dbname=pg_db,
                user=pg_user,
                password=pg_password,
                host=pg_host,
                port=pg_port,
                options=pg_options,
            )
            try:
                for load_element in pg_load:
                    janitor.load(load_element)
                yield db_connection
            finally:
                # Close even when loading initial data fails, so the session is
                # not leaked while the janitor context exits.
                db_connection.close()
Exemple #16
0
    def redis_proc_fixture(request: FixtureRequest,
                           tmp_path_factory: TempPathFactory):
        """
        Fixture for pytest-redis.

        #. Get configs.
        #. Run redis process.
        #. Stop redis process after tests.

        Arguments captured from the enclosing factory take precedence over the
        values returned by ``get_config(request)``.

        :param request: fixture request object
        :param tmp_path_factory: pytest temporary-path factory; used to create
            the data directory when neither ``datadir`` nor the config sets one
        :rtype: pytest_redis.executors.TCPExecutor
        :returns: the started redis executor (a finalizer stops it at teardown)
        """
        config = get_config(request)
        # Truthy explicit values win; otherwise fall back to the config.
        redis_exec = executable or config["exec"]
        # ``is None`` checks (not ``or``) so an explicit False is honoured.
        rdbcompression = config[
            "compression"] if compression is None else compression
        rdbchecksum = config["rdbchecksum"] if checksum is None else checksum

        if datadir:
            redis_datadir = Path(datadir)
        elif config["datadir"]:
            redis_datadir = Path(config["datadir"])
        else:
            # No configured data dir: use a fresh temp dir named after this fixture.
            redis_datadir = tmp_path_factory.mktemp(
                f"pytest-redis-{request.fixturename}")

        # NOTE(review): ``timeout=60`` is hard-coded and separate from
        # ``redis_timeout`` — presumably the executor's startup wait; confirm.
        redis_executor = RedisExecutor(
            executable=redis_exec,
            databases=db_count or config["db_count"],
            redis_timeout=timeout or config["timeout"],
            loglevel=loglevel or config["loglevel"],
            rdbcompression=rdbcompression,
            rdbchecksum=rdbchecksum,
            syslog_enabled=syslog or config["syslog"],
            save=save or config["save"],
            host=host or config["host"],
            port=get_port(port) or get_port(config["port"]),
            timeout=60,
            datadir=redis_datadir,
        )
        redis_executor.start()
        request.addfinalizer(redis_executor.stop)

        return redis_executor
Exemple #17
0
def output_file_url(request: pytest.FixtureRequest,
                    file_server: SimpleWebServer) -> str:
    """Point bokeh's output at a per-test HTML file and return the URL serving it.

    The file sits in the requesting test module's directory and is named after
    the test function. A finalizer deletes it at teardown if it was written.
    """
    from bokeh.io import output_file

    html_file = request.fspath.dirpath().join(f"{request.function.__name__}.html")
    local_path = html_file.strpath

    output_file(local_path, mode='inline')

    def tear_down() -> None:
        if html_file.isfile():
            html_file.remove()

    request.addfinalizer(tear_down)

    # Forward slashes keep the URL valid for Windows paths.
    return file_server.where_is(local_path.replace('\\', '/'))
def _push_request_context(request: pytest.FixtureRequest):
    """During tests execution request context has been pushed, e.g. `url_for`,
    `session`, etc. can be used in tests as is::

        def test_app(app, client):
            assert client.get(url_for('myview')).status_code == 200

    """
    if "app" not in request.fixturenames:
        return
    app = request.getfixturevalue("app")
    ctx = app.test_request_context()
    ctx.push()

    def teardown():
        ctx.pop()

    request.addfinalizer(teardown)
Exemple #19
0
async def session_context(
    request: pytest.FixtureRequest,
    event_loop: AbstractEventLoop,
) -> contextvars.Context:
    """Set up a transactional DB session for the test and return a context snapshot.

    Unless the test carries the ``nosession`` keyword, this does three things:

    - opens a connection with ``use_transaction=True``;
    - installs a fresh session in ``DatabaseSession`` and points every factory
      at it;
    - registers a finalizer that rolls the transaction back, ignoring errors.

    In all cases a finalizer restores every context variable to its value in
    the returned snapshot.
    """
    if "nosession" not in request.keywords:
        await initialize_connection("spellbot-test", use_transaction=True)

        DatabaseSession.set(db_session_maker())  # type: ignore

        factories = (
            BlockFactory,
            ChannelFactory,
            ConfigFactory,
            GameFactory,
            GuildAwardFactory,
            GuildFactory,
            PlayFactory,
            UserAwardFactory,
            UserFactory,
            VerifyFactory,
            WatchFactory,
        )
        for factory in factories:
            factory._meta.sqlalchemy_session = DatabaseSession  # type: ignore

        def cleanup_session():
            async def _rollback() -> None:
                try:
                    await rollback_transaction()
                except Exception:  # pragma: no cover
                    pass

            event_loop.run_until_complete(_rollback())

        request.addfinalizer(cleanup_session)

    context = contextvars.copy_context()

    def cleanup_context():
        for var, value in context.items():
            var.set(value)

    request.addfinalizer(cleanup_context)

    return context
Exemple #20
0
def reports_test_dir(request: pytest.FixtureRequest) -> str:
    """Returns the relative path to the test reports directory.

    Also adds a finalizer that removes the test reports directory and its
    contents once the test has completed. The directory is created by the
    tests that need it, so the finalizer tolerates it not existing.

    Args:
        request (pytest.FixtureRequest): A pytest fixture providing information of the requesting test function.

    Returns:
        str: The relative path to the test reports directory.
    """
    # Relative path to the test reports directory, created by the relevant tests.
    reports_test_dir_path = "./src/test/reports"

    def remove_reports_dir_contents() -> None:
        try:
            shutil.rmtree(reports_test_dir_path)
        except FileNotFoundError:
            # The requesting test never created the directory: nothing to clean.
            pass

    request.addfinalizer(remove_reports_dir_contents)
    return reports_test_dir_path
Exemple #21
0
    def update_parameterized(cls, request: pytest.FixtureRequest,
                             config: BaseConfig):
        """Update the given configuration object with parameterized values if the key is present.

        Keys that ``config`` has no attribute for are skipped silently. So are
        keys whose fixture lookup raises ``FixtureLookupError`` or ``AttributeError``.
        """
        config_type = config.__class__.__name__
        parameterized_keys = cls._get_parameterized_keys(request)

        for fixture_name in parameterized_keys:
            if not hasattr(config, fixture_name):
                # Not a field of this config type (presumably it targets another
                # config object) - skip it.
                continue

            with suppress(pytest.FixtureLookupError, AttributeError):
                value = request.getfixturevalue(fixture_name)
                config.set_value(fixture_name, value)

                log.debug(
                    f"{config_type}.{fixture_name} value updated from parameterized value to {value}"
                )
Exemple #22
0
async def flush_db_if_needed(request: FixtureRequest):
    """Flush database after each test.

    Applies only to tests marked ``database`` with a truthy ``transaction``
    kwarg. After such a test, every Account, File, MediaType, Namespace and
    User object is deleted through the ``session_db_client`` fixture.
    """
    try:
        yield
    finally:
        # Guards are inverted instead of using ``return``: a ``return`` inside
        # ``finally`` would silently swallow any exception raised at ``yield``.
        marker = request.node.get_closest_marker("database")
        if marker and marker.kwargs.get("transaction", False):
            session_db_client: DBClient = request.getfixturevalue(
                "session_db_client")
            await session_db_client.execute("""
            DELETE Account;
            DELETE File;
            DELETE MediaType;
            DELETE Namespace;
            DELETE User;
        """)
Exemple #23
0
def test_change_container_fail(happi_client: Client, item: str, target: str,
                               request: pytest.FixtureRequest):
    """Moving ``item`` into ``target`` must be rejected with TransferError."""
    source, destination = (request.getfixturevalue(name) for name in (item, target))
    with pytest.raises(TransferError):
        happi_client.change_container(source, destination)
def file_server(request: pytest.FixtureRequest) -> SimpleWebServer:
    """Start a SimpleWebServer; a finalizer stops it when the fixture is torn down."""
    web_server = SimpleWebServer()
    web_server.start()
    request.addfinalizer(web_server.stop)
    return web_server
def device(request: pytest.FixtureRequest) -> DysonDevice:
    """Yield the parametrized device while platform discovery is patched.

    ``request.param`` is called with no arguments to build the device. For the
    fixture's lifetime, ``{MODULE}._async_get_platforms`` returns ``["sensor"]``.
    """
    platforms_patch = patch(f"{MODULE}._async_get_platforms", return_value=["sensor"])
    with platforms_patch:
        device_instance = request.param()
        yield device_instance
Exemple #26
0
async def recorder(
    request: pytest.FixtureRequest, ) -> typing.Optional[RecorderFixture]:
    """Set up VCR recording/replay of HTTP traffic for a functional test.

    Only active for unittest-style classes or tests marked ``recorder``;
    otherwise returns None. Cassettes live in a per-test directory under
    ``CASSETTE_LIBRARY_DIR_BASE``.

    - When ``RECORD`` is true, the directory is wiped, traffic is recorded and
      a fresh ``config.json`` describing the testing organization/repository
      is written.
    - Otherwise cassettes are replayed and the GitHub token helpers are
      patched to return ``"<TOKEN>"``.

    Returns a RecorderFixture built from ``config.json`` and the VCR instance.
    """
    is_unittest_class = request.cls is not None

    marker = request.node.get_closest_marker("recorder")
    if not is_unittest_class and marker is None:
        return None

    if is_unittest_class:
        cassette_library_dir = os.path.join(
            CASSETTE_LIBRARY_DIR_BASE,
            request.cls.__name__,
            request.node.name,
        )
    else:
        cassette_library_dir = os.path.join(
            CASSETTE_LIBRARY_DIR_BASE,
            request.node.module.__name__.replace(
                "mergify_engine.tests.functional.", "").replace(".", "/"),
            request.node.name,
        )

    # Recording stuffs
    if RECORD:
        if os.path.exists(cassette_library_dir):
            shutil.rmtree(cassette_library_dir)
        os.makedirs(cassette_library_dir)

    recorder = vcr.VCR(
        cassette_library_dir=cassette_library_dir,
        record_mode="all" if RECORD else "none",
        match_on=["method", "uri"],
        ignore_localhost=True,
        filter_headers=[
            ("Authorization", "<TOKEN>"),
            ("X-Hub-Signature", "<SIGNATURE>"),
            ("User-Agent", None),
            ("Accept-Encoding", None),
            ("Connection", None),
        ],
        before_record_response=pyvcr_response_filter,
        before_record_request=pyvcr_request_filter,
    )

    if RECORD:
        github.CachedToken.STORAGE = {}
    else:
        # Never expire token during replay
        patcher = mock.patch.object(github_app,
                                    "get_or_create_jwt",
                                    return_value="<TOKEN>")
        patcher.start()
        request.addfinalizer(patcher.stop)
        patcher = mock.patch.object(
            github.GithubAppInstallationAuth,
            "get_access_token",
            return_value="<TOKEN>",
        )
        patcher.start()
        request.addfinalizer(patcher.stop)

    # Let's start recording
    cassette = recorder.use_cassette("http.json")
    cassette.__enter__()
    request.addfinalizer(cassette.__exit__)
    record_config_file = os.path.join(cassette_library_dir, "config.json")

    if RECORD:
        with open(record_config_file, "w", encoding="utf-8") as f:
            f.write(
                json.dumps(
                    RecordConfigType({
                        "organization_id":
                        config.TESTING_ORGANIZATION_ID,
                        "organization_name":
                        config.TESTING_ORGANIZATION_NAME,
                        "repository_id":
                        config.TESTING_REPOSITORY_ID,
                        "repository_name":
                        github_types.GitHubRepositoryName(
                            config.TESTING_REPOSITORY_NAME),
                        # Aware UTC "now": same digits as the deprecated utcnow().
                        "branch_prefix":
                        datetime.datetime.now(datetime.timezone.utc).strftime(
                            "%Y%m%d%H%M%S"),
                    })))

    with open(record_config_file, "r", encoding="utf-8") as f:
        return RecorderFixture(
            typing.cast(RecordConfigType, json.loads(f.read())), recorder)
Exemple #27
0
def test_automatic_mock_event_emission(
    server_url_fixture: str,
    mock_client_wait_timeout: float,
    mock_client_wait_interval: float,
    client: socketio.Client,
    request: pytest.FixtureRequest,
):
    """The mock server emits "new message", "typing" and "user joined" events on its own.

    Each handler validates the payload and acknowledges through a Mock; the
    test asserts every Mock was called after a short wait.
    """
    server_url: str = request.getfixturevalue(server_url_fixture)
    new_message_event = "new message"
    new_message_mock_ack = Mock()

    # The field constraints must sit under "properties": top-level keys that
    # are not JSON Schema keywords are ignored, which would accept any payload.
    new_message_schema = {
        "type": "object",
        "properties": {
            "username": {"type": "string"},
            "message": {"type": "string"},
        },
        "required": ["username", "message"],
    }

    @client.on(new_message_event)
    def _new_message_handler(data):
        jsonschema.validate(data, new_message_schema)

        # Assert that message is of sentence format:
        assert data["message"].endswith(".")
        assert " " in data["message"]

        # Assert that username is a first name:
        assert data["username"].istitle()

        new_message_mock_ack(new_message_event)

    typing_event = "typing"
    typing_mock_ack = Mock()
    typing_schema = {
        "type": "object",
        "properties": {"username": {"type": "string"}},
        "required": ["username"],
    }

    @client.on(typing_event)
    def _typing_handler(data):
        jsonschema.validate(data, typing_schema)

        # Assert that username is a first name:
        assert data["username"].istitle()
        typing_mock_ack(typing_event)

    user_joined_event = "user joined"
    user_joined_mock_ack = Mock()
    # numUsers is type-checked when present; only username is read below.
    user_joined_schema = {
        "type": "object",
        "properties": {
            "username": {"type": "string"},
            "numUsers": {"type": "integer"},
        },
        "required": ["username"],
    }

    @client.on(user_joined_event)
    def _user_joined_handler(data):
        jsonschema.validate(data, user_joined_schema)

        # Assert that username is a first name:
        assert data["username"].istitle()
        user_joined_mock_ack(user_joined_event)

    client.connect(server_url, wait_timeout=mock_client_wait_timeout)
    # Wait for all messages to arrive:
    client.sleep(mock_client_wait_interval)

    new_message_mock_ack.assert_called_with(new_message_event)
    typing_mock_ack.assert_called_with(typing_event)
    user_joined_mock_ack.assert_called_with(user_joined_event)
Exemple #28
0
async def dashboard(redis_cache: redis_utils.RedisCache,
                    request: pytest.FixtureRequest) -> DashboardFixture:
    """Seed dashboard subscription/user-token caches and patch their lookups.

    The testing organization gets two things:

    - a subscription with all features when the ``subscription`` marker's
      first argument (or, for unittest-style classes, ``SUBSCRIPTION_ACTIVE``)
      is truthy, otherwise only PUBLIC_REPOSITORY;
    - two user tokens.

    Lookups for other owners are patched to return a PUBLIC_REPOSITORY-only
    subscription and no user tokens. ``ApplicationSaas.get`` only accepts
    ``api_key_admin``. Every patch is stopped by a finalizer.

    Returns a DashboardFixture of (admin API key, subscription, user tokens).
    """
    is_unittest_class = request.cls is not None
    subscription_active = False
    marker = request.node.get_closest_marker("subscription")
    if marker:
        subscription_active = marker.args[0]
    elif is_unittest_class:
        subscription_active = request.cls.SUBSCRIPTION_ACTIVE

    # The first 32 characters act as the API access key, the last 32 as the secret.
    api_key_admin = "a" * 64

    sub = subscription.Subscription(
        redis_cache,
        config.TESTING_ORGANIZATION_ID,
        "You're not nice",
        frozenset(
            getattr(subscription.Features, f)
            for f in subscription.Features.__members__) if subscription_active
        else frozenset([subscription.Features.PUBLIC_REPOSITORY]),
    )
    await sub._save_subscription_to_cache()
    user_tokens = user_tokens_mod.UserTokens(
        redis_cache,
        config.TESTING_ORGANIZATION_ID,
        [
            {
                "id": github_types.GitHubAccountIdType(config.ORG_ADMIN_ID),
                "login": github_types.GitHubLogin("mergify-test1"),
                "oauth_access_token": config.ORG_ADMIN_GITHUB_APP_OAUTH_TOKEN,
                "name": None,
                "email": None,
            },
            {
                "id": github_types.GitHubAccountIdType(config.ORG_USER_ID),
                "login": github_types.GitHubLogin("mergify-test4"),
                "oauth_access_token": config.ORG_USER_PERSONAL_TOKEN,
                "name": None,
                "email": None,
            },
        ],
    )
    await typing.cast(user_tokens_mod.UserTokensSaas,
                      user_tokens).save_to_cache()

    def start_patch(target: str, side_effect: typing.Any) -> None:
        """Patch ``target`` with ``side_effect`` until the fixture is torn down."""
        patcher = mock.patch(target, side_effect=side_effect)
        patcher.start()
        request.addfinalizer(patcher.stop)

    # Captured before patching so the fakes can delegate to the real method.
    real_get_subscription = subscription.Subscription.get_subscription

    async def fake_retrieve_subscription_from_db(redis_cache, owner_id):
        if owner_id == config.TESTING_ORGANIZATION_ID:
            return sub
        # A one-element feature set, like the inactive subscription above;
        # ``set(<enum member>)`` would iterate the member instead.
        return subscription.Subscription(
            redis_cache,
            owner_id,
            "We're just testing",
            frozenset([subscription.Features.PUBLIC_REPOSITORY]),
        )

    async def fake_subscription(redis_cache, owner_id):
        if owner_id == config.TESTING_ORGANIZATION_ID:
            return await real_get_subscription(redis_cache, owner_id)
        return subscription.Subscription(
            redis_cache,
            owner_id,
            "We're just testing",
            frozenset([subscription.Features.PUBLIC_REPOSITORY]),
        )

    start_patch(
        "mergify_engine.dashboard.subscription.Subscription._retrieve_subscription_from_db",
        fake_retrieve_subscription_from_db,
    )
    start_patch(
        "mergify_engine.dashboard.subscription.Subscription.get_subscription",
        fake_subscription,
    )

    async def fake_retrieve_user_tokens_from_db(redis_cache, owner_id):
        if owner_id == config.TESTING_ORGANIZATION_ID:
            return user_tokens
        return user_tokens_mod.UserTokens(redis_cache, owner_id, {})

    # Captured before the UserTokensSaas.get patch below is started.
    real_get_user_tokens = user_tokens_mod.UserTokens.get

    async def fake_user_tokens(redis_cache, owner_id):
        if owner_id == config.TESTING_ORGANIZATION_ID:
            return await real_get_user_tokens(redis_cache, owner_id)
        return user_tokens_mod.UserTokens(redis_cache, owner_id, {})

    start_patch(
        "mergify_engine.dashboard.user_tokens.UserTokensSaas._retrieve_from_db",
        fake_retrieve_user_tokens_from_db,
    )
    start_patch(
        "mergify_engine.dashboard.user_tokens.UserTokensSaas.get",
        fake_user_tokens,
    )

    async def fake_application_get(redis_cache, api_access_key, api_secret_key,
                                   account_scope):
        if (api_access_key == api_key_admin[:32]
                and api_secret_key == api_key_admin[32:]):
            return application_mod.Application(
                redis_cache,
                123,
                "testing application",
                api_access_key,
                api_secret_key,
                account_scope={
                    "id": config.TESTING_ORGANIZATION_ID,
                    "login": config.TESTING_ORGANIZATION_NAME,
                },
            )
        raise application_mod.ApplicationUserNotFound()

    start_patch(
        "mergify_engine.dashboard.application.ApplicationSaas.get",
        fake_application_get,
    )

    return DashboardFixture(
        api_key_admin,
        sub,
        user_tokens,
    )
Exemple #29
0
def cfg_all(request: FixtureRequest):
    """Resolve the fixture whose name is this fixture's parameter."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)
Exemple #30
0
def app(request: pytest.FixtureRequest) -> "SecurityFixture":
    """Build the SecurityFixture application used by the test suite.

    Testing-friendly config is set. Each optional Flask-Security feature in the
    loop below is enabled only when the requesting test has a keyword of the
    same name. A ``settings`` marker's kwargs override ``SECURITY_<KEY>``
    config values. A ``babel`` marker initialises Babel, or skips the test
    when ``NO_BABEL`` is set.

    A set of views exercising the auth decorators is registered, and a
    finalizer removes form fields that tests may add dynamically.
    """
    app = SecurityFixture(__name__)
    app.response_class = Response
    app.debug = True
    app.config["SECRET_KEY"] = "secret"
    app.config["TESTING"] = True
    app.config["LOGIN_DISABLED"] = False
    app.config["WTF_CSRF_ENABLED"] = False
    # Our test emails/domain isn't necessarily valid
    app.config["SECURITY_EMAIL_VALIDATOR_ARGS"] = {"check_deliverability": False}
    app.config["SECURITY_TWO_FACTOR_SECRET"] = {
        "1": "TjQ9Qa31VOrfEzuPy4VHQWPCTmRzCnFzMKLxXYiZu9B"
    }
    app.config["SECURITY_SMS_SERVICE"] = "test"
    app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False

    app.config["SECURITY_PASSWORD_SALT"] = "salty"
    # Make this plaintext for most tests - reduces unit test time by 50%
    app.config["SECURITY_PASSWORD_HASH"] = "plaintext"
    # Make this hex_md5 for token tests
    app.config["SECURITY_HASHING_SCHEMES"] = ["hex_md5"]
    app.config["SECURITY_DEPRECATED_HASHING_SCHEMES"] = []

    # Each feature flag is True only if the test has a keyword/marker with that name.
    for opt in [
        "changeable",
        "recoverable",
        "registerable",
        "trackable",
        "passwordless",
        "confirmable",
        "two_factor",
        "unified_signin",
        "webauthn",
    ]:
        app.config["SECURITY_" + opt.upper()] = opt in request.keywords

    # pytest >= 4 exposes get_closest_marker; older versions fall back to keywords.get.
    pytest_major = int(pytest.__version__.split(".")[0])
    if pytest_major >= 4:
        marker_getter = request.node.get_closest_marker
    else:
        marker_getter = request.keywords.get
    settings = marker_getter("settings")
    if settings is not None:
        for key, value in settings.kwargs.items():
            app.config["SECURITY_" + key.upper()] = value

    app.mail = Mail(app)  # type: ignore
    app.json_encoder = JSONEncoder

    # use babel marker to signify tests that need babel extension.
    babel = marker_getter("babel")
    if babel:
        if NO_BABEL:
            # pytest.skip() raises on its own; the ``raise`` is redundant but harmless.
            raise pytest.skip("Requires Babel")
        Babel(app)

    @app.route("/")
    def index():
        return render_template("index.html", content="Home Page")

    @app.route("/profile")
    @auth_required()
    def profile():
        if hasattr(app, "security"):
            if app.security._want_json(flask_request):
                return jsonify(message="profile")

        return render_template("index.html", content="Profile Page")

    @app.route("/post_login")
    @login_required
    def post_login():
        return render_template("index.html", content="Post Login")

    @app.route("/http")
    @http_auth_required
    def http():
        return "HTTP Authentication"

    @app.route("/http_admin_required")
    @http_auth_required
    @permissions_required("admin")
    def http_admin_required():
        assert get_request_attr("fs_authn_via") == "basic"
        return "HTTP Authentication"

    @app.route("/http_custom_realm")
    @http_auth_required("My Realm")
    def http_custom_realm():
        assert get_request_attr("fs_authn_via") == "basic"
        return render_template("index.html", content="HTTP Authentication")

    @app.route("/token", methods=["GET", "POST"])
    @auth_token_required
    def token():
        assert get_request_attr("fs_authn_via") == "token"
        return render_template("index.html", content="Token Authentication")

    @app.route("/multi_auth")
    @auth_required("session", "token", "basic")
    def multi_auth():
        return render_template("index.html", content="Session, Token, Basic auth")

    @app.route("/post_logout")
    def post_logout():
        return render_template("index.html", content="Post Logout")

    @app.route("/post_register")
    def post_register():
        return render_template("index.html", content="Post Register")

    @app.route("/post_confirm")
    def post_confirm():
        return render_template("index.html", content="Post Confirm")

    @app.route("/admin")
    @roles_required("admin")
    def admin():
        assert get_request_attr("fs_authn_via") == "session"
        return render_template("index.html", content="Admin Page")

    @app.route("/admin_and_editor")
    @roles_required("admin", "editor")
    def admin_and_editor():
        return render_template("index.html", content="Admin and Editor Page")

    @app.route("/admin_or_editor")
    @roles_accepted("admin", "editor")
    def admin_or_editor():
        return render_template("index.html", content="Admin or Editor Page")

    @app.route("/simple")
    @roles_accepted("simple")
    def simple():
        return render_template("index.html", content="SimplePage")

    @app.route("/admin_perm")
    @permissions_accepted("full-write", "super")
    def admin_perm():
        return render_template(
            "index.html", content="Admin Page with full-write or super"
        )

    @app.route("/admin_perm_required")
    @permissions_required("full-write", "super")
    def admin_perm_required():
        return render_template("index.html", content="Admin Page required")

    @app.route("/page1")
    def page_1():
        return "Page 1"

    @app.route("/json", methods=["GET", "POST"])
    def echo_json():
        return jsonify(flask_request.get_json())

    @app.route("/unauthz", methods=["GET", "POST"])
    def unauthz():
        return render_template("index.html", content="Unauthorized")

    @app.route("/fresh", methods=["GET", "POST"])
    @auth_required(within=60)
    def fresh():
        if app.security._want_json(flask_request):
            return jsonify(title="Fresh Only")
        else:
            return render_template("index.html", content="Fresh Only")

    def revert_forms():
        # Some forms/tests have dynamic fields - be sure to revert them.
        if hasattr(app, "security"):
            if hasattr(app.security.login_form, "email"):
                del app.security.login_form.email
            if hasattr(app.security.register_form, "username"):
                del app.security.register_form.username
            if hasattr(app.security.confirm_register_form, "username"):
                del app.security.confirm_register_form.username

    request.addfinalizer(revert_forms)
    return app