Example #1
    def test_tmppath_relative_basetemp_absolute(self, tmp_path, monkeypatch):
        from _pytest.tmpdir import TempPathFactory

        monkeypatch.chdir(tmp_path)
        config = FakeConfig("hello")
        t = TempPathFactory.from_config(config)
        assert t.getbasetemp().resolve() == (tmp_path / "hello").resolve()
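Examples #1, #2, and #7 rely on a FakeConfig helper that the snippets never define. Below is a minimal sketch of such a stub, modeled on the one in pytest's own test suite; treat the exact attribute set as an assumption, since TempPathFactory.from_config only needs config.option.basetemp and a config.trace object:

class FakeConfig:
    # Hypothetical stand-in for pytest's Config object, just rich enough
    # for TempPathFactory.from_config.

    def __init__(self, basetemp):
        self.basetemp = basetemp

    @property
    def option(self):
        # from_config reads config.option.basetemp.
        return self

    @property
    def trace(self):
        return self

    def get(self, key):
        # from_config asks config.trace.get("tmpdir") for a tracer;
        # a do-nothing callable is enough here.
        return lambda *args: None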
Example #2
    def test_mktemp(self, tmp_path):
        from _pytest.tmpdir import TempdirFactory, TempPathFactory

        config = FakeConfig(tmp_path)
        t = TempdirFactory(TempPathFactory.from_config(config))
        tmp = t.mktemp("world")
        assert tmp.relto(t.getbasetemp()) == "world0"
        tmp = t.mktemp("this")
        assert tmp.relto(t.getbasetemp()).startswith("this")
        tmp2 = t.mktemp("this")
        assert tmp2.relto(t.getbasetemp()).startswith("this")
        assert tmp2 != tmp
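TempdirFactory is the legacy py.path-based wrapper (hence relto() and string comparisons); TempPathFactory is its pathlib counterpart. The first check rewritten against the pathlib API, as a fragment inside the same test that assumes the FakeConfig stub sketched under Example #1:

from pathlib import Path

t = TempPathFactory.from_config(FakeConfig(tmp_path))
tmp = t.mktemp("world")  # numbered by default, so the directory is "world0"
assert tmp.relative_to(t.getbasetemp()) == Path("world0")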
Example #3
def tmp_enc_dec_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("enc_dec_dir")
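This helper, like the other one-liners below (Examples #5, #11 through #15, #23, #26, #28, and #31), is a pytest fixture whose @pytest.fixture decorator was evidently stripped during extraction. A self-contained version would look roughly like this:

import pytest
from pathlib import Path

from _pytest.tmpdir import TempPathFactory


@pytest.fixture
def tmp_enc_dec_dir(tmp_path_factory: TempPathFactory) -> Path:
    # tmp_path_factory is pytest's built-in session-scoped factory; mktemp
    # returns a fresh numbered directory (enc_dec_dir0, enc_dec_dir1, ...)
    # under the session's base temporary directory.
    return tmp_path_factory.mktemp("enc_dec_dir")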
Example #4
def logging_conf_tmp_file_path(tmp_path_factory: TempPathFactory) -> Path:
    """Copy logging configuration module to custom temporary location."""
    tmp_dir = tmp_path_factory.mktemp("tmp_log")
    shutil.copy(Path(logging_conf_module.__file__),
                Path(f"{tmp_dir}/tmp_log.py"))
    return tmp_dir
Example #5
def tmp_root_homes_dir_enc_dec_ro(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("root-homes-enc-dec-ro")
Example #6
def test_create_model_package(
    tmp_path_factory: TempPathFactory,
    domain: Domain,
):
    train_model_storage = LocalModelStorage(
        tmp_path_factory.mktemp("train model storage"))

    train_schema = GraphSchema({
        "train":
        SchemaNode(
            needs={},
            uses=PersistableTestComponent,
            fn="train",
            constructor_name="create",
            config={
                "some_config": 123455,
                "some more config": [{
                    "nested": "hi"
                }]
            },
        ),
        "load":
        SchemaNode(
            needs={"resource": "train"},
            uses=PersistableTestComponent,
            fn="run_inference",
            constructor_name="load",
            config={},
            is_target=True,
        ),
    })

    predict_schema = GraphSchema({
        "run":
        SchemaNode(
            needs={},
            uses=PersistableTestComponent,
            fn="run",
            constructor_name="load",
            config={
                "some_config": 123455,
                "some more config": [{
                    "nested": "hi"
                }]
            },
        ),
    })

    # Fill model storage
    with train_model_storage.write_to(Resource("resource1")) as directory:
        file = directory / "file.txt"
        file.write_text("test")

    # Package model
    persisted_model_dir = tmp_path_factory.mktemp("persisted models")
    archive_path = persisted_model_dir / "my-model.tar.gz"

    trained_at = datetime.utcnow()
    with freezegun.freeze_time(trained_at):
        train_model_storage.create_model_package(archive_path, train_schema,
                                                 predict_schema, domain)

    # Unpack and inspect packaged model
    load_model_storage_dir = tmp_path_factory.mktemp("load model storage")

    (
        load_model_storage,
        packaged_metadata,
    ) = LocalModelStorage.from_model_archive(load_model_storage_dir,
                                             archive_path)

    assert packaged_metadata.train_schema == train_schema
    assert packaged_metadata.predict_schema == predict_schema
    assert packaged_metadata.domain.as_dict() == domain.as_dict()

    assert packaged_metadata.rasa_open_source_version == rasa.__version__
    assert packaged_metadata.trained_at == trained_at
    assert packaged_metadata.model_id

    persisted_resources = load_model_storage_dir.glob("*")
    assert list(persisted_resources) == [
        Path(load_model_storage_dir, "resource1")
    ]
Example #7
    def test_tmppath_relative_basetemp_absolute(self, tmp_path, monkeypatch):
        """#4425"""
        monkeypatch.chdir(tmp_path)
        config = cast(Config, FakeConfig("hello"))
        t = TempPathFactory.from_config(config)
        assert t.getbasetemp().resolve() == (tmp_path / "hello").resolve()
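This is the same test as Example #1 after a later refactor: the imports moved to module level, and cast(Config, ...) satisfies the type checker now that from_config is annotated as taking a Config. The docstring presumably cites the pytest issue number the test covers.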
Example #8
def test_jieba_load_and_persist_dictionary(
    tmp_path_factory: TempPathFactory,
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    caplog: LogCaptureFixture,
):
    dictionary_directory = tmp_path_factory.mktemp("dictionaries")
    dictionary_path = dictionary_directory / "dictionary_1"

    dictionary_contents = """
创新办 3 i
云计算 5
凱特琳 nz
台中
        """
    dictionary_path.write_text(dictionary_contents, encoding="utf-8")

    component_config = {"dictionary_path": dictionary_directory}

    resource = Resource("jieba")
    tk = JiebaTokenizerGraphComponent.create(
        {
            **JiebaTokenizerGraphComponent.get_default_config(),
            **component_config
        },
        default_model_storage,
        resource,
        default_execution_context,
    )

    tk.process_training_data(TrainingData([Message(data={TEXT: ""})]))

    # The dictionary has not been persisted yet.
    with caplog.at_level(logging.WARN):
        JiebaTokenizerGraphComponent.load(
            {
                **JiebaTokenizerGraphComponent.get_default_config(),
                **component_config
            },
            default_model_storage,
            resource,
            default_execution_context,
        )
        assert any(
            "Failed to load JiebaTokenizerGraphComponent from model storage."
            in message for message in caplog.messages)

    tk.persist()

    # Check the persisted dictionary matches the original file.
    with default_model_storage.read_from(resource) as resource_dir:
        contents = (resource_dir / "dictionary_1").read_text(encoding="utf-8")
        assert contents == dictionary_contents

    # Delete original files to show that we read from the model storage.
    dictionary_path.unlink()
    dictionary_directory.rmdir()

    JiebaTokenizerGraphComponent.load(
        {
            **JiebaTokenizerGraphComponent.get_default_config(),
            **component_config
        },
        default_model_storage,
        resource,
        default_execution_context,
    )

    tk.process([Message(data={TEXT: ""})])
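The test exercises the full persistence cycle: create the tokenizer with a custom dictionary, confirm that load() before persist() only logs a warning, persist, verify the stored dictionary matches the original byte-for-byte, then delete the source files to prove the second load() reads solely from model storage.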
Example #9
def _mk_case_x_writable_tmp_fixture_dir(
        idx: int, tmp_path_factory: TempPathFactory) -> Path:
    tmp_dir = tmp_path_factory.mktemp(f"case{idx}_sad")
    tmp_cwd_dir = tmp_dir.joinpath("device-ssh")
    copytree(_get_case_x_repo_src_dir(idx), tmp_cwd_dir)
    return tmp_cwd_dir
Example #10
def _mk_gpg_ctx_w_info_fixture_no_checks(
        gen_fn: _GpgCtxGenFnT, tmp_factory: TempPathFactory,
        request: OptPyTestFixtureRequestT) -> GpgContextWGenInfo:
    home_dir = tmp_factory.mktemp("home_user")
    gpg_ctx = gen_fn(home_dir, request)
    return gpg_ctx
Example #11
def temp_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("test-tmp")
Example #12
def tmp_user_home_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("home_user")
Example #13
def tmp_export_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("export_dir")
Example #14
def tmp_root_homes_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("root_homes")
Example #15
def gunicorn_conf_tmp_path(tmp_path_factory: TempPathFactory) -> Path:
    """Create temporary directory for Gunicorn configuration file."""
    return tmp_path_factory.mktemp("gunicorn")
Example #16
def app_module_tmp_path(tmp_path_factory: TempPathFactory) -> Path:
    """Copy app modules to temporary directory to test custom app module paths."""
    tmp_dir = tmp_path_factory.mktemp("app")
    shutil.copytree(
        Path(pre_start_module.__file__).parent, Path(f"{tmp_dir}/tmp_app"))
    return tmp_dir
Example #17
def test_constraints_validation(tmp_path_factory: TempPathFactory,
                                rule_runner: RuleRunner) -> None:
    find_links = create_dists(
        tmp_path_factory.mktemp("sdists"),
        Project("Foo-Bar", "1.0.0"),
        Project("Bar", "5.5.5"),
        Project("baz", "2.2.2"),
        Project("QUX", "3.4.5"),
    )

    # Turn the project dir into a git repo, so it can be cloned.
    foorl_dir = create_project_dir(tmp_path_factory.mktemp("git"),
                                   Project("foorl", "9.8.7"))
    with pushd(str(foorl_dir)):
        subprocess.check_call(["git", "init"])
        subprocess.check_call(["git", "config", "user.name", "dummy"])
        subprocess.check_call(
            ["git", "config", "user.email", "*****@*****.**"])
        subprocess.check_call(["git", "add", "--all"])
        subprocess.check_call(["git", "commit", "-m", "initial commit"])
        subprocess.check_call(["git", "branch", "9.8.7"])

    # This string won't parse as a Requirement if it doesn't contain a netloc,
    # so we explicitly mention localhost.
    url_req = f"foorl@ git+file://localhost{foorl_dir.as_posix()}@9.8.7"

    rule_runner.add_to_build_file(
        "",
        dedent(f"""
            python_requirement(name="foo", requirements=["foo-bar>=0.1.2"])
            python_requirement(name="bar", requirements=["bar==5.5.5"])
            python_requirement(name="baz", requirements=["baz"])
            python_requirement(name="foorl", requirements=["{url_req}"])
            python_sources(name="util", sources=[], dependencies=[":foo", ":bar"])
            python_sources(name="app", sources=[], dependencies=[":util", ":baz", ":foorl"])
            """),
    )
    rule_runner.create_file(
        "constraints1.txt",
        dedent("""
            # Comment.
            --find-links=https://duckduckgo.com
            Foo._-BAR==1.0.0  # Inline comment.
            bar==5.5.5
            baz==2.2.2
            qux==3.4.5
            # Note that pip does not allow URL requirements in constraints files,
            # so there is no mention of foorl here.
        """),
    )

    def get_pex_request(
            constraints_file: str | None,
            resolve_all_constraints: bool | None,
            *,
            direct_deps_only: bool = False,
            additional_args: Iterable[str] = (),
            additional_lockfile_args: Iterable[str] = (),
    ) -> PexRequest:
        args = ["--backend-packages=pants.backend.python"]
        request = PexFromTargetsRequest(
            [Address("", target_name="app")],
            output_filename="demo.pex",
            internal_only=True,
            direct_deps_only=direct_deps_only,
            additional_args=additional_args,
            additional_lockfile_args=additional_lockfile_args,
        )
        if resolve_all_constraints is not None:
            args.append(
                f"--python-resolve-all-constraints={resolve_all_constraints!r}"
            )
        if constraints_file:
            args.append(f"--python-requirement-constraints={constraints_file}")
        args.append("--python-repos-indexes=[]")
        args.append(f"--python-repos-repos={find_links}")
        rule_runner.set_options(args, env_inherit={"PATH"})
        pex_request = rule_runner.request(PexRequest, [request])
        assert OrderedSet(additional_args).issubset(
            OrderedSet(pex_request.additional_args))
        return pex_request

    additional_args = ["--strip-pex-env"]
    additional_lockfile_args = ["--no-strip-pex-env"]

    pex_req1 = get_pex_request("constraints1.txt",
                               resolve_all_constraints=False)
    assert pex_req1.requirements == PexRequirements(
        ["foo-bar>=0.1.2", "bar==5.5.5", "baz", url_req],
        apply_constraints=True)

    pex_req1_direct = get_pex_request("constraints1.txt",
                                      resolve_all_constraints=False,
                                      direct_deps_only=True)
    assert pex_req1_direct.requirements == PexRequirements(
        ["baz", url_req], apply_constraints=True)

    pex_req2 = get_pex_request(
        "constraints1.txt",
        resolve_all_constraints=True,
        additional_args=additional_args,
        additional_lockfile_args=additional_lockfile_args,
    )
    pex_req2_reqs = pex_req2.requirements
    assert isinstance(pex_req2_reqs, PexRequirements)
    assert list(pex_req2_reqs.req_strings) == [
        "bar==5.5.5", "baz", "foo-bar>=0.1.2", url_req
    ]
    assert pex_req2_reqs.repository_pex is not None
    assert not info(rule_runner, pex_req2_reqs.repository_pex)["strip_pex_env"]
    repository_pex = pex_req2_reqs.repository_pex
    assert [
        "Foo._-BAR==1.0.0", "bar==5.5.5", "baz==2.2.2", "foorl", "qux==3.4.5"
    ] == requirements(rule_runner, repository_pex)

    pex_req2_direct = get_pex_request(
        "constraints1.txt",
        resolve_all_constraints=True,
        direct_deps_only=True,
        additional_args=additional_args,
        additional_lockfile_args=additional_lockfile_args,
    )
    pex_req2_reqs = pex_req2_direct.requirements
    assert isinstance(pex_req2_reqs, PexRequirements)
    assert list(pex_req2_reqs.req_strings) == ["baz", url_req]
    assert pex_req2_reqs.repository_pex == repository_pex
    assert not info(rule_runner, pex_req2_reqs.repository_pex)["strip_pex_env"]

    pex_req3_direct = get_pex_request("constraints1.txt",
                                      resolve_all_constraints=True,
                                      direct_deps_only=True)
    pex_req3_reqs = pex_req3_direct.requirements
    assert isinstance(pex_req3_reqs, PexRequirements)
    assert list(pex_req3_reqs.req_strings) == ["baz", url_req]
    assert pex_req3_reqs.repository_pex is not None
    assert pex_req3_reqs.repository_pex != repository_pex
    assert info(rule_runner, pex_req3_reqs.repository_pex)["strip_pex_env"]

    with pytest.raises(ExecutionError) as err:
        get_pex_request(None, resolve_all_constraints=True)
    assert len(err.value.wrapped_exceptions) == 1
    assert isinstance(err.value.wrapped_exceptions[0], ValueError)
    assert ("`[python].resolve_all_constraints` is enabled, so "
            "`[python].requirement_constraints` must also be set.") in str(
                err.value)

    # Shouldn't error, as we don't explicitly set --resolve-all-constraints.
    get_pex_request(None, resolve_all_constraints=None)
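Example #30 below is an earlier revision of this same test: the target types were still python_requirement_library and python_library, the options lived under [python-setup] rather than [python], and resolve_all_constraints was a tri-state ResolveAllConstraintsOption enum rather than a boolean.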
Example #18
def setup(tmp_path_factory: TempPathFactory):
    tmp_directory = str(tmp_path_factory.getbasetemp())
    src_dir = os.path.join(os.path.dirname(__file__), 'sbroot_test')
    sbroot = os.path.join(tmp_directory, 'sbroot_test')
    shutil.copytree(src_dir, sbroot)
    yield sbroot
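Because getbasetemp() returns the session's shared base directory, the copytree destination here is fixed, and a second invocation in the same session would fail with FileExistsError. A variant built on mktemp (a sketch) gives each call its own numbered directory:

import os
import shutil

import pytest
from _pytest.tmpdir import TempPathFactory


@pytest.fixture
def setup(tmp_path_factory: TempPathFactory):
    src_dir = os.path.join(os.path.dirname(__file__), 'sbroot_test')
    # mktemp hands back a fresh numbered directory on every call, so the
    # copy destination can never already exist.
    sbroot = tmp_path_factory.mktemp('sbroot') / 'sbroot_test'
    shutil.copytree(src_dir, sbroot)
    yield str(sbroot)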
Example #19
def patch_default_output_location(monkeypatch_session: MonkeyPatch, tmp_path_factory: TempPathFactory) -> None:
    """Patch output location by default to avoid tests writing into the repository during development."""
    tmp_path = tmp_path_factory.mktemp("test_forecasting_platform")
    print(f"Using temporary directory for default_output_location: {tmp_path}")
    monkeypatch_session.setattr(master_config, "default_output_location", tmp_path)
Example #20
def _docker_registry_insecure(
    *,
    docker_client: DockerClient,
    docker_compose_insecure_list: List[Path],
    docker_services: Services,
    request,
    scale_factor: int,
    tmp_path_factory: TempPathFactory,
) -> Generator[List[DockerRegistryInsecure], None, None]:
    """Provides the endpoint of a local, mutable, insecure, docker registry."""
    cache_key = _docker_registry_insecure.__name__
    result = CACHE.get(cache_key, [])
    for i in range(scale_factor):
        if i < len(result):
            continue

        service_name = DOCKER_REGISTRY_SERVICE_PATTERN.format("insecure", i)
        tmp_path = tmp_path_factory.mktemp(__name__)

        # Create an insecure registry service from the docker compose template ...
        path_docker_compose = tmp_path.joinpath(f"docker-compose-{i}.yml")
        template = Template(docker_compose_insecure_list[i].read_text("utf-8"))
        path_docker_compose.write_text(
            template.substitute({
                "CONTAINER_NAME": service_name,
                # Note: Needed to correctly populate the embedded, consolidated, service template ...
                "PATH_CERTIFICATE": "/dev/null",
                "PATH_HTPASSWD": "/dev/null",
                "PATH_KEY": "/dev/null",
            }),
            "utf-8",
        )

        LOGGER.debug("Starting insecure docker registry service [%d] ...", i)
        LOGGER.debug("  docker-compose : %s", path_docker_compose)
        LOGGER.debug("  service name   : %s", service_name)
        endpoint = start_service(
            docker_services,
            docker_compose=path_docker_compose,
            private_port=DOCKER_REGISTRY_PORT_INSECURE,
            service_name=service_name,
        )
        LOGGER.debug("Insecure docker registry endpoint [%d]: %s", i, endpoint)

        images = []
        if i == 0:
            LOGGER.debug("Replicating images into %s [%d] ...", service_name,
                         i)
            images = _replicate_images(docker_client, endpoint, request)

        result.append(
            DockerRegistryInsecure(
                docker_client=docker_client,
                docker_compose=path_docker_compose,
                endpoint=endpoint,
                endpoint_name=f"{service_name}:{DOCKER_REGISTRY_PORT_INSECURE}",
                images=images,
                service_name=service_name,
            ))
    CACHE[cache_key] = result
    yield result
Example #21
    def model_storage(self, tmp_path_factory: TempPathFactory) -> ModelStorage:
        return LocalModelStorage(tmp_path_factory.mktemp(uuid.uuid4().hex))
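Randomizing the basename with uuid.uuid4().hex gives every fixture instantiation a distinct directory name, on top of the numeric suffix mktemp already appends to repeated basenames.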
Example #22
def _docker_registry_secure(
    *,
    docker_client: DockerClient,
    docker_compose_secure_list: List[Path],
    docker_registry_auth_header_list: List[Dict[str, str]],
    docker_registry_cacerts_list: List[Path],
    docker_registry_certs_list: List[DockerRegistryCerts],
    docker_registry_htpasswd_list: List[Path],
    docker_registry_password_list: List[str],
    docker_registry_ssl_context_list: List[SSLContext],
    docker_registry_username_list: List[str],
    docker_services: Services,
    request,
    scale_factor: int,
    tmp_path_factory: TempPathFactory,
) -> Generator[List[DockerRegistrySecure], None, None]:
    """Provides the endpoint of a local, mutable, secure, docker registry."""
    cache_key = _docker_registry_secure.__name__
    result = CACHE.get(cache_key, [])
    for i in range(scale_factor):
        if i < len(result):
            continue

        service_name = DOCKER_REGISTRY_SERVICE_PATTERN.format("secure", i)
        tmp_path = tmp_path_factory.mktemp(__name__)

        # Create a secure registry service from the docker compose template ...
        path_docker_compose = tmp_path.joinpath(f"docker-compose-{i}.yml")
        template = Template(docker_compose_secure_list[i].read_text("utf-8"))
        path_docker_compose.write_text(
            template.substitute({
                "CONTAINER_NAME":
                service_name,
                "PATH_CERTIFICATE":
                docker_registry_certs_list[i].certificate,
                "PATH_HTPASSWD":
                docker_registry_htpasswd_list[i],
                "PATH_KEY":
                docker_registry_certs_list[i].private_key,
            }),
            "utf-8",
        )

        LOGGER.debug("Starting secure docker registry service [%d] ...", i)
        LOGGER.debug("  docker-compose : %s", path_docker_compose)
        LOGGER.debug("  ca certificate : %s",
                     docker_registry_certs_list[i].ca_certificate)
        LOGGER.debug("  certificate    : %s",
                     docker_registry_certs_list[i].certificate)
        LOGGER.debug("  htpasswd       : %s", docker_registry_htpasswd_list[i])
        LOGGER.debug("  private key    : %s",
                     docker_registry_certs_list[i].private_key)
        LOGGER.debug("  password       : %s", docker_registry_password_list[i])
        LOGGER.debug("  service name   : %s", service_name)
        LOGGER.debug("  username       : %s", docker_registry_username_list[i])

        check_server = partial(
            check_url_secure,
            auth_header=docker_registry_auth_header_list[i],
            ssl_context=docker_registry_ssl_context_list[i],
        )
        endpoint = start_service(
            docker_services,
            check_server=check_server,
            docker_compose=path_docker_compose,
            private_port=DOCKER_REGISTRY_PORT_SECURE,
            service_name=service_name,
        )
        LOGGER.debug("Secure docker registry endpoint [%d]: %s", i, endpoint)

        # DUCK PUNCH: Inject the secure docker registry credentials into the docker client ...
        docker_client.api._auth_configs.add_auth(  # pylint: disable=protected-access
            endpoint,
            {
                "password": docker_registry_password_list[i],
                "username": docker_registry_username_list[i],
            },
        )

        images = []
        if i == 0:
            LOGGER.debug("Replicating images into %s [%d] ...", service_name,
                         i)
            images = _replicate_images(docker_client, endpoint, request)

        result.append(
            DockerRegistrySecure(
                auth_header=docker_registry_auth_header_list[i],
                cacerts=docker_registry_cacerts_list[i],
                certs=docker_registry_certs_list[i],
                docker_client=docker_client,
                docker_compose=path_docker_compose,
                endpoint=endpoint,
                endpoint_name=f"{service_name}:{DOCKER_REGISTRY_PORT_SECURE}",
                htpasswd=docker_registry_htpasswd_list[i],
                password=docker_registry_password_list[i],
                images=images,
                service_name=service_name,
                ssl_context=docker_registry_ssl_context_list[i],
                username=docker_registry_username_list[i],
            ))
    CACHE[cache_key] = result
    yield result
Example #23
def _trained_e2e_model_cache(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("cache")
Example #24
def test_empty_directories_are_equal(tmp_path_factory: TempPathFactory):
    dir1 = tmp_path_factory.mktemp("dir1")
    dir2 = tmp_path_factory.mktemp("dir2")

    assert rasa.utils.io.are_directories_equal(dir1, dir2)
Example #25
def test_code_verification(tmp_path_factory: TempPathFactory,
                           model: Model) -> None:
    utils.assert_provable_code(model, Integration(),
                               tmp_path_factory.mktemp("code_verification"))
Example #26
def tgt_tmp_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("tgt")
Example #27
def happy_day_data(
        tmp_path_factory: TempPathFactory,
        request
):
    tmp_path = tmp_path_factory.mktemp("data")
    key_file = tmp_path / 'dev.key.pem'
    certificate_file = tmp_path / 'dev.cert.der'
    manifest_version = request.param  # Type[ManifestAsnCodecBase]
    generate_credentials(
        key_file=key_file,
        certificate_file=certificate_file,
        do_overwrite=False,
        cred_valid_time=8
    )
    fw_file = tmp_path / 'fw.bin'
    fw_file.write_bytes(os.urandom(512))

    input_cfg = {
        "manifest-version": manifest_version.get_name(),
        "vendor": {
            "domain": "arm.com",
            "custom-data-path": fw_file.as_posix()

        },
        "device": {
            "model-name": "my-device"
        },
        "priority": 15,
        "payload": {
            "url": "https://my.server.com/some.file?new=1",
            "file-path": fw_file.as_posix(),
            "format": "raw-binary"
        }
    }

    fw_version = '100.500.0'
    if 'v1' == manifest_version.get_name():
        fw_version = 0
    else:
        input_cfg['sign-image'] = True

    manifest_data = CreateAction.do_create(
        pem_key_data=key_file.read_bytes(),
        input_cfg=input_cfg,
        fw_version=fw_version,
        update_certificate=certificate_file,
        asn1_codec_class=manifest_version
    )
    manifest_file = tmp_path / 'fota_manifest.bin'
    manifest_file.write_bytes(manifest_data)

    private_key = serialization.load_pem_private_key(
        key_file.read_bytes(),
        password=None,
        backend=default_backend()
    )
    public_key = private_key.public_key()
    public_key_bytes = public_key.public_bytes(
        encoding=serialization.Encoding.X962,
        format=serialization.PublicFormat.UncompressedPoint
    )
    public_key_file = tmp_path / 'pub_key.bin'
    public_key_file.write_bytes(public_key_bytes)

    return {
        'manifest_file': manifest_file,
        'certificate_file': certificate_file,
        'pub_key_file': public_key_file,
        'priv_key_file': key_file,
        'manifest_version': manifest_version.get_name(),
    }
Example #28
def tmp_root_homes_dir_init(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("root-homes-init")
Example #29
def test_loader_loads_graph_runner(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
    tmp_path: Path,
    tmp_path_factory: TempPathFactory,
    domain_path: Path,
):
    graph_trainer = GraphTrainer(
        model_storage=default_model_storage,
        cache=temp_cache,
        graph_runner_class=DaskGraphRunner,
    )

    test_value = "test_value"

    train_schema = GraphSchema(
        {
            "train": SchemaNode(
                needs={},
                uses=PersistableTestComponent,
                fn="train",
                constructor_name="create",
                config={"test_value": test_value},
                is_target=True,
            ),
            "load": SchemaNode(
                needs={"resource": "train"},
                uses=PersistableTestComponent,
                fn="run_inference",
                constructor_name="load",
                config={},
            ),
        }
    )
    predict_schema = GraphSchema(
        {
            "load": SchemaNode(
                needs={},
                uses=PersistableTestComponent,
                fn="run_inference",
                constructor_name="load",
                config={},
                is_target=True,
                resource=Resource("train"),
            )
        }
    )

    output_filename = tmp_path / "model.tar.gz"

    importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[], domain_path=str(domain_path)
    )

    trained_at = datetime.utcnow()
    with freezegun.freeze_time(trained_at):
        model_metadata = graph_trainer.train(
            GraphModelConfiguration(
                train_schema=train_schema,
                predict_schema=predict_schema,
                training_type=TrainingType.BOTH,
                language=None,
                core_target=None,
                nlu_target=None,
            ),
            importer=importer,
            output_filename=output_filename,
        )

    assert isinstance(model_metadata, ModelMetadata)
    assert output_filename.is_file()

    loaded_model_storage_path = tmp_path_factory.mktemp("loaded model storage")

    model_metadata, loaded_predict_graph_runner = loader.load_predict_graph_runner(
        storage_path=loaded_model_storage_path,
        model_archive_path=output_filename,
        model_storage_class=LocalModelStorage,
        graph_runner_class=DaskGraphRunner,
    )

    assert loaded_predict_graph_runner.run() == {"load": test_value}

    assert model_metadata.predict_schema == predict_schema
    assert model_metadata.train_schema == train_schema
    assert model_metadata.model_id
    assert model_metadata.domain.as_dict() == Domain.from_path(domain_path).as_dict()
    assert model_metadata.rasa_open_source_version == rasa.__version__
    assert model_metadata.trained_at == trained_at
Example #30
def test_constraints_validation(tmp_path_factory: TempPathFactory,
                                rule_runner: RuleRunner) -> None:
    find_links = create_dists(
        tmp_path_factory.mktemp("sdists"),
        Project("Foo-Bar", "1.0.0"),
        Project("Bar", "5.5.5"),
        Project("baz", "2.2.2"),
        Project("QUX", "3.4.5"),
    )

    rule_runner.add_to_build_file(
        "",
        dedent("""
            python_requirement_library(name="foo", requirements=["foo-bar>=0.1.2"])
            python_requirement_library(name="bar", requirements=["bar==5.5.5"])
            python_requirement_library(name="baz", requirements=["baz"])
            python_library(name="util", sources=[], dependencies=[":foo", ":bar"])
            python_library(name="app", sources=[], dependencies=[":util", ":baz"])
            """),
    )
    rule_runner.create_file(
        "constraints1.txt",
        dedent("""
            # Comment.
            --find-links=https://duckduckgo.com
            Foo._-BAR==1.0.0  # Inline comment.
            bar==5.5.5
            baz==2.2.2
            qux==3.4.5
        """),
    )

    def get_pex_request(
        constraints_file: Optional[str],
        resolve_all: Optional[ResolveAllConstraintsOption],
        *,
        direct_deps_only: bool = False,
    ) -> PexRequest:
        args = ["--backend-packages=pants.backend.python"]
        request = PexFromTargetsRequest(
            [Address("", target_name="app")],
            output_filename="demo.pex",
            internal_only=True,
            direct_deps_only=direct_deps_only,
        )
        if resolve_all:
            args.append(
                f"--python-setup-resolve-all-constraints={resolve_all.value}")
        if constraints_file:
            args.append(
                f"--python-setup-requirement-constraints={constraints_file}")
        args.append("--python-repos-indexes=[]")
        args.append(f"--python-repos-repos={find_links}")
        rule_runner.set_options(args, env_inherit={"PATH"})
        return rule_runner.request(PexRequest, [request])

    pex_req1 = get_pex_request("constraints1.txt",
                               ResolveAllConstraintsOption.NEVER)
    assert pex_req1.requirements == PexRequirements(
        ["foo-bar>=0.1.2", "bar==5.5.5", "baz"])
    assert pex_req1.repository_pex is None

    pex_req1_direct = get_pex_request("constraints1.txt",
                                      ResolveAllConstraintsOption.NEVER,
                                      direct_deps_only=True)
    assert pex_req1_direct.requirements == PexRequirements(["baz"])
    assert pex_req1_direct.repository_pex is None

    pex_req2 = get_pex_request("constraints1.txt",
                               ResolveAllConstraintsOption.ALWAYS)
    assert pex_req2.requirements == PexRequirements(
        ["foo-bar>=0.1.2", "bar==5.5.5", "baz"])
    assert pex_req2.repository_pex is not None
    repository_pex = pex_req2.repository_pex
    assert ["Foo._-BAR==1.0.0", "bar==5.5.5", "baz==2.2.2",
            "qux==3.4.5"] == requirements(rule_runner, repository_pex)

    pex_req2_direct = get_pex_request("constraints1.txt",
                                      ResolveAllConstraintsOption.ALWAYS,
                                      direct_deps_only=True)
    assert pex_req2_direct.requirements == PexRequirements(["baz"])
    assert pex_req2_direct.repository_pex == repository_pex

    with pytest.raises(ExecutionError) as err:
        get_pex_request(None, ResolveAllConstraintsOption.ALWAYS)
    assert len(err.value.wrapped_exceptions) == 1
    assert isinstance(err.value.wrapped_exceptions[0], ValueError)
    assert (
        "[python-setup].resolve_all_constraints is set to always, so "
        "either [python-setup].requirement_constraints or "
        "[python-setup].requirement_constraints_target must also be provided."
    ) in str(err.value)

    # Shouldn't error, as we don't explicitly set --resolve-all-constraints.
    get_pex_request(None, None)
Example #31
def src_tmp_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("src")
Example #32
    def runtest(self):
        """Run an AnyScript test item."""

        tmpdir = Path(
            TempdirFactory(TempPathFactory.from_config(self.config)).mktemp(
                self.name))

        with change_dir(tmpdir):
            self.app = AnyPyProcess(**self.app_opts)
            result = self.app.start_macro(self.macro)[0]

        # Ignore error due to missing Main.RunTest
        if "ERROR" in result:
            runtest_missing = any("Error : Main.RunTest :" in err
                                  for err in result["ERROR"])
            if runtest_missing:
                runtest_errors = (
                    "Error : Main.RunTest : Unresolved",
                    "Main.RunTest : Select Operation",
                    "Error : run : command unexpected while",
                )
                result["ERROR"][:] = [
                    err for err in result["ERROR"]
                    if not any(s in err for s in runtest_errors)
                ]
        # Check that the expected errors are present
        error_list = result.get("ERROR", [])
        if self.expect_errors:
            for xerr in self.expect_errors:
                xerr_found = False
                for error in error_list[:]:
                    if xerr in error:
                        xerr_found = True
                        error_list.remove(error)
                if not xerr_found:
                    self.errors.append("TEST ERROR: Expected error not "
                                       'found: "{}"'.format(xerr))

        # Add remaining errors to item's error list
        if error_list:
            self.errors.extend(error_list)

        # Add info to the hdf5 file if compare output was set
        if self.hdf5_outputs:
            base = Path(self.config.getoption("--anytest-output"))
            subfolder = Path(self.config.getoption("--anytest-name"))
            target = base / subfolder / self.name
            self.save_output_files(tmpdir, target, result, self.hdf5_outputs)

        if self.errors and self.config.getoption("--create-macros"):
            logfile = result["task_logfile"]
            shutil.copyfile(logfile, self.fspath / (self.name + ".txt"))
            shutil.copyfile(logfile.with_suffix(".anymcr"),
                            self.fspath / (self.name + ".anymcr"))
            macro_name = _write_macro_file(self.fspath.dirname, self.name,
                                           self.macro)

        shutil.rmtree(tmpdir, ignore_errors=True)

        if len(self.errors) > 0:
            raise AnyException(self)

        return
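A caveat for code like Examples #1, #2, #7, and #32: recent pytest releases treat TempPathFactory.from_config and direct construction of TempdirFactory/TempPathFactory as private API (both gained an _ispytest flag and warn when called without it), so new plugin code is generally better served by the tmp_path_factory fixture, or by the factory instance pytest itself stores on the config object.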