Ejemplo n.º 1
0
def block_network(request: SubRequest, record_mode: str,
                  vcr_markers: List[Mark]) -> Iterator[None]:
    """Block network access for a test unless VCR is actively recording.

    Blocking is requested by the ``block_network`` mark or by the
    ``--block-network`` CLI option. Even when requested, it is skipped if the
    test carries ``vcr`` marks and the effective record mode is anything other
    than ``"none"``, because recording needs real network access.

    Args:
        request: The pytest request for the current test.
        record_mode: The incoming record mode. When the test has ``vcr`` marks
            and ``--record-mode`` was not passed on the CLI, it is replaced by
            the value from the merged VCR configuration.
        vcr_markers: The ``vcr`` marks that apply to the test (may be empty).

    Yields:
        ``None`` exactly once; the test runs while the generator is suspended,
        inside the blocking context when blocking applies.
    """
    marker = request.node.get_closest_marker(name="block_network")
    if marker is not None:
        validate_block_network_mark(marker)
    block_requested = marker or request.config.getoption("--block-network")
    # Hosts listed on the mark take precedence over `--allowed-hosts`.
    allowed_hosts = getattr(
        marker, "kwargs",
        {}).get("allowed_hosts") or request.config.getoption("--allowed-hosts")
    if isinstance(allowed_hosts, str):
        # The value may arrive as a comma-separated string.
        allowed_hosts = allowed_hosts.split(",")
    if vcr_markers:
        # Take `record_mode` with the most priority:
        #  - Explicit CLI option
        #  - The `vcr_config` fixture
        #  - The `vcr` mark
        config = request.getfixturevalue("vcr_config")
        merged_config = merge_kwargs(config, vcr_markers)
        # If `--record-mode` was not explicitly passed in CLI, then take one from the merged config
        if request.config.getoption("--record-mode") is None:
            record_mode = merged_config.get("record_mode", "none")
    # `vcr_markers` is already injected as a parameter, so there is no need to
    # resolve the fixture a second time through `request.getfixturevalue`.
    if block_requested and (not vcr_markers or record_mode == "none"):
        with network.blocking_context(allowed_hosts=allowed_hosts):
            yield
    else:
        yield
Ejemplo n.º 2
0
def record_stdout(
    disable_recording: bool,
    record_stdout_markers: List[Mark],
    record_mode: str,
    request: SubRequest,
):
    """Capture a marked test's stdout and check it against a record.

    Does nothing (yields ``None`` once) when recording is disabled or the test
    has no ``record_stdout`` mark. Otherwise the test's stdout is captured,
    handed to a ``Recorder``, and then persisted and/or asserted depending on
    the formatted mark arguments.

    Args:
        disable_recording: When true, the fixture is a no-op.
        record_stdout_markers: The ``record_stdout`` marks passed on to
            ``record_stdout_format_kwargs``.
        record_mode: Record mode passed on to ``record_stdout_format_kwargs``.
        request: The pytest request for the current test.

    Yields:
        ``None`` exactly once; the test runs while the generator is suspended.
    """
    marker = request.node.get_closest_marker("record_stdout")

    # Nothing to do: recording is switched off or the test is not marked.
    if disable_recording or not marker:
        yield None
        return

    # SETUP TEST DETAILS
    module_dir = request.node.fspath.dirname
    module_name = request.node.fspath.purebasename
    test_name = request.node.name

    # FORMAT MARKER'S KEYWORD ARGUMENTS
    formatted_kwargs = record_stdout_format_kwargs(
        test_name=test_name,
        record_mode=record_mode,
        record_stdout_markers=record_stdout_markers,
    )

    # SETUP RECORDER
    path_template = PathTemplate(
        module_dir=module_dir,
        module_name=module_name,
        test_name=formatted_kwargs["record_name"],
    )
    recorder = Recorder(path_template=path_template,
                        record_mode=formatted_kwargs["record_mode"])

    # CAPTURE STDOUT
    capture = request.config.getoption("--capture")
    if capture == "no":
        # With `--capture=no` the fixture installs its own capture of the
        # three standard streams instead of relying on `capsys`.
        global_capturing = MultiCapture(in_=SysCapture(0),
                                        out=SysCapture(1),
                                        err=SysCapture(2))
        global_capturing.start_capturing()
        try:
            yield
            recorder.capture(
                captured=global_capturing.readouterr().out,
                strip=formatted_kwargs["strip"],
            )
        finally:
            # Always undo the capture, even if reading or recording the
            # output fails, so the streams are not left redirected.
            global_capturing.stop_capturing()
    else:
        capsys = request.getfixturevalue("capsys")
        yield
        recorder.capture(captured=capsys.readouterr().out,
                         strip=formatted_kwargs["strip"])

    # SAVE/CHECK RECORD
    if formatted_kwargs["save_record"]:
        recorder.persist()
        recorder.assert_equal()
    # The membership check runs whether or not the record was saved.
    recorder.assert_in_list(in_list=formatted_kwargs["assert_in_list"])
Ejemplo n.º 3
0
def vcr(  # pylint: disable=too-many-arguments
    request: SubRequest,
    vcr_markers: List[Mark],
    vcr_cassette_dir: str,
    record_mode: str,
    disable_recording: bool,
    pytestconfig: Config,
) -> Iterator[Optional[Cassette]]:
    """Provide a cassette for tests marked with `pytest.mark.vcr`.

    Yields ``None`` when recording is disabled or the test has no ``vcr``
    marks; otherwise yields the cassette opened by ``use_cassette`` and keeps
    it open while the test runs.
    """
    # No cassette is installed for unmarked tests or when recording is off.
    if disable_recording or not vcr_markers:
        yield None
        return

    vcr_config = request.getfixturevalue("vcr_config")
    cassette_name = request.getfixturevalue("default_cassette_name")
    cassette_context = use_cassette(cassette_name, vcr_cassette_dir,
                                    record_mode, vcr_markers, vcr_config,
                                    pytestconfig)
    with cassette_context as active_cassette:
        yield active_cassette
Ejemplo n.º 4
0
def xfail_if_unseeded_model_chosen(request: SubRequest):
    """Mark seeding-related tests as xfail when the chosen model is not properly seeded.

    The mark is only added when all of the following hold: the ``model_name``
    fixture value is not in ``properly_seeded_models``, the test uses the
    ``num`` fixture, that fixture's value is not 0, and the requested function
    equals one of the known seeding tests.
    """
    # TODO: Can't refer to the tests by reference because of `parametrize_this`: their `__name__`
    # becomes "method". If/when we rework `parametrize_this` or change the name of tests, it'll
    # be important to update this as well.
    # NOTE: Ideally a fixture would be attached to just these tests, but their signature
    # apparently can't be changed because of `@phase` and/or `parametrize_this`.
    seeding_tests = (
        TestHEBO.test_seed_rng,
        TestHEBO.test_seed_rng_init,
        TestHEBO.test_state_dict,
    )

    model_name: str = request.getfixturevalue("model_name")
    if model_name in properly_seeded_models:
        # The test is expected to pass, so no mark is needed.
        return

    # NOTE: The phase is detected so that tests do not XPASS in the random phase (where
    # some of them do work). Tests without the `num` fixture do not involve the phase
    # at all; a `num` of 0 means the random phase.
    if "num" not in request.fixturenames or request.getfixturevalue("num") == 0:
        return

    # NOTE: `request.function` points to the local closure inside `parametrize_this`,
    # hence the explicit equality scan over the known seeding tests.
    if not any(candidate == request.function for candidate in seeding_tests):
        return

    xfail_mark = pytest.mark.xfail(
        reason=f"This model name {model_name} is not properly seeded.",
    )
    request.node.add_marker(xfail_mark)
Ejemplo n.º 5
0
def block_network(request: SubRequest, record_mode: str) -> Iterator[None]:
    """Block network access for a test unless VCR is actively recording.

    Blocking is requested by the ``block_network`` mark or the
    ``--block-network`` option. It is skipped when the test has ``vcr`` marks
    and ``record_mode`` is anything other than ``"none"``.
    """
    mark = request.node.get_closest_marker(name="block_network")
    should_block = mark or request.config.getoption("--block-network")

    # Hosts listed on the mark take precedence over `--allowed-hosts`.
    mark_kwargs = getattr(mark, "kwargs", {})
    hosts = mark_kwargs.get("allowed_hosts") or request.config.getoption(
        "--allowed-hosts")
    if isinstance(hosts, str):
        hosts = hosts.split(",")

    if not should_block:
        yield
        return

    # The one exception to blocking: a VCR-marked test in a recording mode
    # (any mode except "none") keeps its network access.
    if request.getfixturevalue("vcr_markers") and record_mode != "none":
        yield
        return

    with network.blocking_context(allowed_hosts=hosts):
        yield
Ejemplo n.º 6
0
    def test_after_processing(
        self,
        vws_client: VWS,
        request: SubRequest,
        image_fixture_name: str,
        expected_status: TargetStatuses,
    ) -> None:
        """
        Once processing has finished, the tracking rating lies between 0 and
        5 inclusive.

        The documentation says:

        > Note: tracking_rating and reco_rating are provided only when
        > status = success.

        This test demonstrates otherwise: ``tracking_rating`` is given even
        when the status is not success, and ``reco_rating`` is not provided
        even when the status is success.
        """
        image = request.getfixturevalue(image_fixture_name)

        new_target_id = vws_client.add_target(
            name='example',
            width=1,
            image=image,
            active_flag=True,
            application_metadata=None,
        )

        # The rating can still change while the target is being processed,
        # so only inspect it once processing has ended.
        vws_client.wait_for_target_processed(target_id=new_target_id)

        summary = vws_client.get_target_summary_report(
            target_id=new_target_id,
        )
        record = vws_client.get_target_record(
            target_id=new_target_id,
        ).target_record

        # The summary report and the target record must agree on the rating.
        assert summary.tracking_rating == record.tracking_rating
        assert summary.tracking_rating in range(6)
        assert summary.status == expected_status
        assert summary.total_recos == 0
        assert summary.current_month_recos == 0
        assert summary.previous_month_recos == 0
Ejemplo n.º 7
0
 def auto_convert_schema(self, request: SubRequest) -> None:
     """Replace a string ``schema`` attribute with the fixture of that name.

     If ``self.schema`` exists and is a ``str``, it is treated as a fixture
     name and overwritten with that fixture's value; otherwise nothing
     happens.
     """
     schema = getattr(self, "schema", None)
     if isinstance(schema, str):
         self.schema = request.getfixturevalue(schema)
Ejemplo n.º 8
0
def endpoint(request: SubRequest) -> Endpoint:
    """
    Return details of an endpoint for the Target API or the Query API.

    The current parameter (``request.param``) is used as the name of the
    fixture to resolve, and that fixture's value is returned.
    """
    return request.getfixturevalue(request.param)