Example #1
def test_jieba_load_and_persist_dictionary(
    tmp_path_factory: TempPathFactory,
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    caplog: LogCaptureFixture,
):
    dictionary_directory = tmp_path_factory.mktemp("dictionaries")
    dictionary_path = dictionary_directory / "dictionary_1"

    dictionary_contents = """
创新办 3 i
云计算 5
凱特琳 nz
台中
        """
    dictionary_path.write_text(dictionary_contents, encoding="utf-8")

    component_config = {"dictionary_path": dictionary_directory}

    resource = Resource("jieba")
    tk = JiebaTokenizer.create(
        {
            **JiebaTokenizer.get_default_config(),
            **component_config
        },
        default_model_storage,
        resource,
        default_execution_context,
    )

    tk.process_training_data(TrainingData([Message(data={TEXT: ""})]))

    # The dictionary has not been persisted yet.
    with caplog.at_level(logging.DEBUG):
        JiebaTokenizer.load(
            {
                **JiebaTokenizer.get_default_config(),
                **component_config
            },
            default_model_storage,
            resource,
            default_execution_context,
        )
        assert any(
            "Failed to load JiebaTokenizer from model storage." in message
            for message in caplog.messages)

    tk.persist()

    # Check the persisted dictionary matches the original file.
    with default_model_storage.read_from(resource) as resource_dir:
        contents = (resource_dir / "dictionary_1").read_text(encoding="utf-8")
        assert contents == dictionary_contents

    # Delete original files to show that we read from the model storage.
    dictionary_path.unlink()
    dictionary_directory.rmdir()

    JiebaTokenizer.load(
        {
            **JiebaTokenizer.get_default_config(),
            **component_config
        },
        default_model_storage,
        resource,
        default_execution_context,
    )

    tk.process([Message(data={TEXT: ""})])
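
The file written above uses jieba's user-dictionary format: one entry per line as "word [frequency] [POS tag]", with frequency and tag both optional (all four variants appear in the test data). A minimal standalone sketch of loading such a file with jieba directly, outside the Rasa wrapper (the file path is illustrative):

import jieba

# Each line of the user dictionary is "word [frequency] [POS tag]",
# e.g. "创新办 3 i" or just "台中"; loading it keeps these domain terms
# from being split into smaller tokens.
jieba.load_userdict("dictionary_1")
tokens = list(jieba.cut("小明是创新办主任也是云计算方面的专家"))
print(tokens)  # "创新办" and "云计算" now come out as whole tokens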
Example #2
def _docker_registry_insecure(
    *,
    docker_client: DockerClient,
    docker_compose_insecure_list: List[Path],
    docker_services: Services,
    request,
    scale_factor: int,
    tmp_path_factory: TempPathFactory,
) -> Generator[List[DockerRegistryInsecure], None, None]:
    """Provides the endpoint of a local, mutable, insecure, docker registry."""
    cache_key = _docker_registry_insecure.__name__
    result = CACHE.get(cache_key, [])
    for i in range(scale_factor):
        if i < len(result):
            continue

        service_name = DOCKER_REGISTRY_SERVICE_PATTERN.format("insecure", i)
        tmp_path = tmp_path_factory.mktemp(__name__)

        # Create an insecure registry service from the docker compose template ...
        path_docker_compose = tmp_path.joinpath(f"docker-compose-{i}.yml")
        template = Template(docker_compose_insecure_list[i].read_text("utf-8"))
        path_docker_compose.write_text(
            template.substitute({
                "CONTAINER_NAME": service_name,
                # Note: Needed to correctly populate the embedded, consolidated, service template ...
                "PATH_CERTIFICATE": "/dev/null",
                "PATH_HTPASSWD": "/dev/null",
                "PATH_KEY": "/dev/null",
            }),
            "utf-8",
        )

        LOGGER.debug("Starting insecure docker registry service [%d] ...", i)
        LOGGER.debug("  docker-compose : %s", path_docker_compose)
        LOGGER.debug("  service name   : %s", service_name)
        endpoint = start_service(
            docker_services,
            docker_compose=path_docker_compose,
            private_port=DOCKER_REGISTRY_PORT_INSECURE,
            service_name=service_name,
        )
        LOGGER.debug("Insecure docker registry endpoint [%d]: %s", i, endpoint)

        images = []
        if i == 0:
            LOGGER.debug("Replicating images into %s [%d] ...", service_name,
                         i)
            images = _replicate_images(docker_client, endpoint, request)

        result.append(
            DockerRegistryInsecure(
                docker_client=docker_client,
                docker_compose=path_docker_compose,
                endpoint=endpoint,
                endpoint_name=f"{service_name}:{DOCKER_REGISTRY_PORT_INSECURE}",
                images=images,
                service_name=service_name,
            ))
    CACHE[cache_key] = result
    yield result
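
Both registry fixtures render their docker-compose file with the standard library's string.Template, as seen above. A minimal illustration of that substitution step (the template text here is invented for the sketch):

from string import Template

# ${VAR} placeholders are filled from the mapping passed to substitute();
# substitute() raises KeyError if any placeholder is left unfilled.
template = Template("container_name: ${CONTAINER_NAME}\nports:\n  - 5000\n")
rendered = template.substitute({"CONTAINER_NAME": "pytest-registry-insecure-0"})
print(rendered)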
Example #3
def _docker_registry_secure(
    *,
    docker_client: DockerClient,
    docker_compose_secure_list: List[Path],
    docker_registry_auth_header_list: List[Dict[str, str]],
    docker_registry_cacerts_list: List[Path],
    docker_registry_certs_list: List[DockerRegistryCerts],
    docker_registry_htpasswd_list: List[Path],
    docker_registry_password_list: List[str],
    docker_registry_ssl_context_list: List[SSLContext],
    docker_registry_username_list: List[str],
    docker_services: Services,
    request,
    scale_factor: int,
    tmp_path_factory: TempPathFactory,
) -> Generator[List[DockerRegistrySecure], None, None]:
    """Provides the endpoint of a local, mutable, secure, docker registry."""
    cache_key = _docker_registry_secure.__name__
    result = CACHE.get(cache_key, [])
    for i in range(scale_factor):
        if i < len(result):
            continue

        service_name = DOCKER_REGISTRY_SERVICE_PATTERN.format("secure", i)
        tmp_path = tmp_path_factory.mktemp(__name__)

        # Create a secure registry service from the docker compose template ...
        path_docker_compose = tmp_path.joinpath(f"docker-compose-{i}.yml")
        template = Template(docker_compose_secure_list[i].read_text("utf-8"))
        path_docker_compose.write_text(
            template.substitute({
                "CONTAINER_NAME": service_name,
                "PATH_CERTIFICATE": docker_registry_certs_list[i].certificate,
                "PATH_HTPASSWD": docker_registry_htpasswd_list[i],
                "PATH_KEY": docker_registry_certs_list[i].private_key,
            }),
            "utf-8",
        )

        LOGGER.debug("Starting secure docker registry service [%d] ...", i)
        LOGGER.debug("  docker-compose : %s", path_docker_compose)
        LOGGER.debug("  ca certificate : %s",
                     docker_registry_certs_list[i].ca_certificate)
        LOGGER.debug("  certificate    : %s",
                     docker_registry_certs_list[i].certificate)
        LOGGER.debug("  htpasswd       : %s", docker_registry_htpasswd_list[i])
        LOGGER.debug("  private key    : %s",
                     docker_registry_certs_list[i].private_key)
        LOGGER.debug("  password       : %s", docker_registry_password_list[i])
        LOGGER.debug("  service name   : %s", service_name)
        LOGGER.debug("  username       : %s", docker_registry_username_list[i])

        check_server = partial(
            check_url_secure,
            auth_header=docker_registry_auth_header_list[i],
            ssl_context=docker_registry_ssl_context_list[i],
        )
        endpoint = start_service(
            docker_services,
            check_server=check_server,
            docker_compose=path_docker_compose,
            private_port=DOCKER_REGISTRY_PORT_SECURE,
            service_name=service_name,
        )
        LOGGER.debug("Secure docker registry endpoint [%d]: %s", i, endpoint)

        # DUCK PUNCH: Inject the secure docker registry credentials into the docker client ...
        docker_client.api._auth_configs.add_auth(  # pylint: disable=protected-access
            endpoint,
            {
                "password": docker_registry_password_list[i],
                "username": docker_registry_username_list[i],
            },
        )

        images = []
        if i == 0:
            LOGGER.debug("Replicating images into %s [%d] ...", service_name,
                         i)
            images = _replicate_images(docker_client, endpoint, request)

        result.append(
            DockerRegistrySecure(
                auth_header=docker_registry_auth_header_list[i],
                cacerts=docker_registry_cacerts_list[i],
                certs=docker_registry_certs_list[i],
                docker_client=docker_client,
                docker_compose=path_docker_compose,
                endpoint=endpoint,
                endpoint_name=f"{service_name}:{DOCKER_REGISTRY_PORT_SECURE}",
                htpasswd=docker_registry_htpasswd_list[i],
                password=docker_registry_password_list[i],
                images=images,
                service_name=service_name,
                ssl_context=docker_registry_ssl_context_list[i],
                username=docker_registry_username_list[i],
            ))
    CACHE[cache_key] = result
    yield result
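
The "DUCK PUNCH" above seeds credentials by writing into a private attribute of the docker client. For comparison, a sketch of the public docker-py route, DockerClient.login; note that login() performs an actual round-trip against the registry, which is presumably why the fixture pokes the credential store directly instead (endpoint and credentials here are illustrative):

import docker

client = docker.from_env()
# Public-API counterpart of the private add_auth() call: registers the
# credentials for one specific registry endpoint.
client.login(username="pytest", password="secret", registry="127.0.0.1:5000")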
Example #4
def tmp_root_homes_dir_init(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("root-homes-init")
Example #5
def model_storage(self, tmp_path_factory: TempPathFactory) -> ModelStorage:
    return LocalModelStorage(tmp_path_factory.mktemp(uuid.uuid4().hex))
Example #6
def test_create_model_package(tmp_path_factory: TempPathFactory,
                              domain: Domain):
    train_model_storage = LocalModelStorage(
        tmp_path_factory.mktemp("train model storage"))

    train_schema = GraphSchema({
        "train": SchemaNode(
            needs={},
            uses=PersistableTestComponent,
            fn="train",
            constructor_name="create",
            config={
                "some_config": 123455,
                "some more config": [{"nested": "hi"}],
            },
        ),
        "load": SchemaNode(
            needs={"resource": "train"},
            uses=PersistableTestComponent,
            fn="run_inference",
            constructor_name="load",
            config={},
            is_target=True,
        ),
    })

    predict_schema = GraphSchema({
        "run": SchemaNode(
            needs={},
            uses=PersistableTestComponent,
            fn="run",
            constructor_name="load",
            config={
                "some_config": 123455,
                "some more config": [{"nested": "hi"}],
            },
        )
    })

    # Fill model storage
    with train_model_storage.write_to(Resource("resource1")) as directory:
        file = directory / "file.txt"
        file.write_text("test")

    # Package model
    persisted_model_dir = tmp_path_factory.mktemp("persisted models")
    archive_path = persisted_model_dir / "my-model.tar.gz"

    trained_at = datetime.utcnow()
    with freezegun.freeze_time(trained_at):
        train_model_storage.create_model_package(
            archive_path,
            GraphModelConfiguration(train_schema, predict_schema,
                                    TrainingType.BOTH, None, None, "nlu"),
            domain,
        )

    # Unpack and inspect packaged model
    load_model_storage_dir = tmp_path_factory.mktemp("load model storage")

    just_packaged_metadata = LocalModelStorage.metadata_from_archive(
        archive_path)

    load_model_storage, packaged_metadata = LocalModelStorage.from_model_archive(
        load_model_storage_dir, archive_path
    )

    assert just_packaged_metadata.trained_at == packaged_metadata.trained_at

    assert packaged_metadata.train_schema == train_schema
    assert packaged_metadata.predict_schema == predict_schema
    assert packaged_metadata.domain.as_dict() == domain.as_dict()

    assert packaged_metadata.rasa_open_source_version == rasa.__version__
    assert packaged_metadata.trained_at == trained_at
    assert packaged_metadata.model_id
    assert packaged_metadata.project_fingerprint

    persisted_resources = load_model_storage_dir.glob("*")
    assert list(persisted_resources) == [
        Path(load_model_storage_dir, "resource1")
    ]
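
The freezegun context manager is what makes the trained_at comparison above deterministic: it pins the clock for everything inside the with block, so the timestamp recorded during create_model_package() equals the one captured in the test. A tiny illustration of the behavior relied on:

from datetime import datetime

import freezegun

with freezegun.freeze_time("2021-01-01 12:00:00"):
    # Inside the block "now" is frozen, so repeated reads agree exactly.
    assert datetime.utcnow() == datetime(2021, 1, 1, 12, 0)
    assert datetime.utcnow() == datetime.utcnow()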
Example #7
def test_constraints_validation(tmp_path_factory: TempPathFactory,
                                rule_runner: RuleRunner) -> None:
    find_links = create_dists(
        tmp_path_factory.mktemp("sdists"),
        Project("Foo-Bar", "1.0.0"),
        Project("Bar", "5.5.5"),
        Project("baz", "2.2.2"),
        Project("QUX", "3.4.5"),
    )

    # Turn the project dir into a git repo, so it can be cloned.
    foorl_dir = create_project_dir(tmp_path_factory.mktemp("git"),
                                   Project("foorl", "9.8.7"))
    with pushd(str(foorl_dir)):
        subprocess.check_call(["git", "init"])
        subprocess.check_call(["git", "config", "user.name", "dummy"])
        subprocess.check_call(
            ["git", "config", "user.email", "*****@*****.**"])
        subprocess.check_call(["git", "add", "--all"])
        subprocess.check_call(["git", "commit", "-m", "initial commit"])
        subprocess.check_call(["git", "branch", "9.8.7"])

    # This string won't parse as a Requirement if it doesn't contain a netloc,
    # so we explicitly mention localhost.
    url_req = f"foorl@ git+file://localhost{foorl_dir.as_posix()}@9.8.7"

    rule_runner.add_to_build_file(
        "",
        dedent(f"""
            python_requirement(name="foo", requirements=["foo-bar>=0.1.2"])
            python_requirement(name="bar", requirements=["bar==5.5.5"])
            python_requirement(name="baz", requirements=["baz"])
            python_requirement(name="foorl", requirements=["{url_req}"])
            python_sources(name="util", sources=[], dependencies=[":foo", ":bar"])
            python_sources(name="app", sources=[], dependencies=[":util", ":baz", ":foorl"])
            """),
    )
    rule_runner.create_file(
        "constraints1.txt",
        dedent("""
            # Comment.
            --find-links=https://duckduckgo.com
            Foo._-BAR==1.0.0  # Inline comment.
            bar==5.5.5
            baz==2.2.2
            qux==3.4.5
            # Note that pip does not allow URL requirements in constraints files,
            # so there is no mention of foorl here.
        """),
    )

    def get_pex_request(
            constraints_file: str | None,
            resolve_all_constraints: bool | None,
            *,
            direct_deps_only: bool = False,
            additional_args: Iterable[str] = (),
            additional_lockfile_args: Iterable[str] = (),
    ) -> PexRequest:
        args = ["--backend-packages=pants.backend.python"]
        request = PexFromTargetsRequest(
            [Address("", target_name="app")],
            output_filename="demo.pex",
            internal_only=True,
            direct_deps_only=direct_deps_only,
            additional_args=additional_args,
            additional_lockfile_args=additional_lockfile_args,
        )
        if resolve_all_constraints is not None:
            args.append(
                f"--python-resolve-all-constraints={resolve_all_constraints!r}"
            )
        if constraints_file:
            args.append(f"--python-requirement-constraints={constraints_file}")
        args.append("--python-repos-indexes=[]")
        args.append(f"--python-repos-repos={find_links}")
        rule_runner.set_options(args, env_inherit={"PATH"})
        pex_request = rule_runner.request(PexRequest, [request])
        assert OrderedSet(additional_args).issubset(
            OrderedSet(pex_request.additional_args))
        return pex_request

    additional_args = ["--strip-pex-env"]
    additional_lockfile_args = ["--no-strip-pex-env"]

    pex_req1 = get_pex_request("constraints1.txt",
                               resolve_all_constraints=False)
    assert pex_req1.requirements == PexRequirements(
        ["foo-bar>=0.1.2", "bar==5.5.5", "baz", url_req],
        apply_constraints=True)

    pex_req1_direct = get_pex_request("constraints1.txt",
                                      resolve_all_constraints=False,
                                      direct_deps_only=True)
    assert pex_req1_direct.requirements == PexRequirements(
        ["baz", url_req], apply_constraints=True)

    pex_req2 = get_pex_request(
        "constraints1.txt",
        resolve_all_constraints=True,
        additional_args=additional_args,
        additional_lockfile_args=additional_lockfile_args,
    )
    pex_req2_reqs = pex_req2.requirements
    assert isinstance(pex_req2_reqs, PexRequirements)
    assert list(pex_req2_reqs.req_strings) == [
        "bar==5.5.5", "baz", "foo-bar>=0.1.2", url_req
    ]
    assert pex_req2_reqs.repository_pex is not None
    assert not info(rule_runner, pex_req2_reqs.repository_pex)["strip_pex_env"]
    repository_pex = pex_req2_reqs.repository_pex
    assert [
        "Foo._-BAR==1.0.0", "bar==5.5.5", "baz==2.2.2", "foorl", "qux==3.4.5"
    ] == requirements(rule_runner, repository_pex)

    pex_req2_direct = get_pex_request(
        "constraints1.txt",
        resolve_all_constraints=True,
        direct_deps_only=True,
        additional_args=additional_args,
        additional_lockfile_args=additional_lockfile_args,
    )
    pex_req2_reqs = pex_req2_direct.requirements
    assert isinstance(pex_req2_reqs, PexRequirements)
    assert list(pex_req2_reqs.req_strings) == ["baz", url_req]
    assert pex_req2_reqs.repository_pex == repository_pex
    assert not info(rule_runner, pex_req2_reqs.repository_pex)["strip_pex_env"]

    pex_req3_direct = get_pex_request("constraints1.txt",
                                      resolve_all_constraints=True,
                                      direct_deps_only=True)
    pex_req3_reqs = pex_req3_direct.requirements
    assert isinstance(pex_req3_reqs, PexRequirements)
    assert list(pex_req3_reqs.req_strings) == ["baz", url_req]
    assert pex_req3_reqs.repository_pex is not None
    assert pex_req3_reqs.repository_pex != repository_pex
    assert info(rule_runner, pex_req3_reqs.repository_pex)["strip_pex_env"]

    with pytest.raises(ExecutionError) as err:
        get_pex_request(None, resolve_all_constraints=True)
    assert len(err.value.wrapped_exceptions) == 1
    assert isinstance(err.value.wrapped_exceptions[0], ValueError)
    assert ("`[python].resolve_all_constraints` is enabled, so "
            "`[python].requirement_constraints` must also be set.") in str(
                err.value)

    # Shouldn't error, as we don't explicitly set --resolve-all-constraints.
    get_pex_request(None, resolve_all_constraints=None)
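
The localhost netloc in url_req matters because "name @ url" requirements are validated as URLs, and some parsers reject a URL that lacks a network location. A quick check with the packaging library (the repository path is illustrative, and exact validation strictness varies by packaging version):

from packaging.requirements import Requirement

# "git+file:///tmp/foorl@9.8.7" (empty netloc) can be rejected by stricter
# parsers; spelling out "localhost" satisfies the URL grammar.
req = Requirement("foorl@ git+file://localhost/tmp/foorl@9.8.7")
print(req.name, req.url)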
Example #8
def tmp_root_homes_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("root_homes")
Example #9
def tmp_user_home_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("home_user")
Example #10
def app_module_tmp_path(tmp_path_factory: TempPathFactory) -> Path:
    """Copy app modules to temporary directory to test custom app module paths."""
    tmp_dir = tmp_path_factory.mktemp("app")
    shutil.copytree(
        Path(pre_start_module.__file__).parent, Path(f"{tmp_dir}/tmp_app"))
    return tmp_dir
Example #11
def gunicorn_conf_tmp_path(tmp_path_factory: TempPathFactory) -> Path:
    """Create temporary directory for Gunicorn configuration file."""
    return tmp_path_factory.mktemp("gunicorn")
Example #12
def logging_conf_tmp_file_path(tmp_path_factory: TempPathFactory) -> Path:
    """Copy logging configuration module to custom temporary location."""
    tmp_dir = tmp_path_factory.mktemp("tmp_log")
    shutil.copy(Path(logging_conf_module.__file__),
                Path(f"{tmp_dir}/tmp_log.py"))
    return tmp_dir
Example #13
def tgt_pgp_tmp_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("tgt_pgp")
Example #14
def src_pgp_tmp_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("src-pgp")
Example #15
def test_constraints_validation(tmp_path_factory: TempPathFactory,
                                rule_runner: RuleRunner) -> None:
    find_links = create_dists(
        tmp_path_factory.mktemp("sdists"),
        Project("Foo-Bar", "1.0.0"),
        Project("Bar", "5.5.5"),
        Project("baz", "2.2.2"),
        Project("QUX", "3.4.5"),
    )

    rule_runner.add_to_build_file(
        "",
        dedent("""
            python_requirement_library(name="foo", requirements=["foo-bar>=0.1.2"])
            python_requirement_library(name="bar", requirements=["bar==5.5.5"])
            python_requirement_library(name="baz", requirements=["baz"])
            python_library(name="util", sources=[], dependencies=[":foo", ":bar"])
            python_library(name="app", sources=[], dependencies=[":util", ":baz"])
            """),
    )
    rule_runner.create_file(
        "constraints1.txt",
        dedent("""
            # Comment.
            --find-links=https://duckduckgo.com
            Foo._-BAR==1.0.0  # Inline comment.
            bar==5.5.5
            baz==2.2.2
            qux==3.4.5
        """),
    )

    def get_pex_request(
        constraints_file: Optional[str],
        resolve_all: Optional[ResolveAllConstraintsOption],
        *,
        direct_deps_only: bool = False,
    ) -> PexRequest:
        args = ["--backend-packages=pants.backend.python"]
        request = PexFromTargetsRequest(
            [Address("", target_name="app")],
            output_filename="demo.pex",
            internal_only=True,
            direct_deps_only=direct_deps_only,
        )
        if resolve_all:
            args.append(
                f"--python-setup-resolve-all-constraints={resolve_all.value}")
        if constraints_file:
            args.append(
                f"--python-setup-requirement-constraints={constraints_file}")
        args.append("--python-repos-indexes=[]")
        args.append(f"--python-repos-repos={find_links}")
        rule_runner.set_options(args, env_inherit={"PATH"})
        return rule_runner.request(PexRequest, [request])

    pex_req1 = get_pex_request("constraints1.txt",
                               ResolveAllConstraintsOption.NEVER)
    assert pex_req1.requirements == PexRequirements(
        ["foo-bar>=0.1.2", "bar==5.5.5", "baz"])
    assert pex_req1.repository_pex is None

    pex_req1_direct = get_pex_request("constraints1.txt",
                                      ResolveAllConstraintsOption.NEVER,
                                      direct_deps_only=True)
    assert pex_req1_direct.requirements == PexRequirements(["baz"])
    assert pex_req1_direct.repository_pex is None

    pex_req2 = get_pex_request("constraints1.txt",
                               ResolveAllConstraintsOption.ALWAYS)
    assert pex_req2.requirements == PexRequirements(
        ["foo-bar>=0.1.2", "bar==5.5.5", "baz"])
    assert pex_req2.repository_pex is not None
    repository_pex = pex_req2.repository_pex
    assert ["Foo._-BAR==1.0.0", "bar==5.5.5", "baz==2.2.2",
            "qux==3.4.5"] == requirements(rule_runner, repository_pex)

    pex_req2_direct = get_pex_request("constraints1.txt",
                                      ResolveAllConstraintsOption.ALWAYS,
                                      direct_deps_only=True)
    assert pex_req2_direct.requirements == PexRequirements(["baz"])
    assert pex_req2_direct.repository_pex == repository_pex

    with pytest.raises(ExecutionError) as err:
        get_pex_request(None, ResolveAllConstraintsOption.ALWAYS)
    assert len(err.value.wrapped_exceptions) == 1
    assert isinstance(err.value.wrapped_exceptions[0], ValueError)
    assert (
        "[python-setup].resolve_all_constraints is set to always, so "
        "either [python-setup].requirement_constraints or "
        "[python-setup].requirement_constraints_target must also be provided."
    ) in str(err.value)

    # Shouldn't error, as we don't explicitly set --resolve-all-constraints.
    get_pex_request(None, None)
Example #16
def tmp_export_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("export_dir")
Example #17
def tmp_enc_dec_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("enc_dec_dir")
Example #18
def _mk_gpg_ctx_w_info_fixture_no_checks(
        gen_fn: _GpgCtxGenFnT, tmp_factory: TempPathFactory,
        request: OptPyTestFixtureRequestT) -> GpgContextWGenInfo:
    home_dir = tmp_factory.mktemp("home_user")
    gpg_ctx = gen_fn(home_dir, request)
    return gpg_ctx
Example #19
def test_loader_loads_graph_runner(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
    tmp_path: Path,
    tmp_path_factory: TempPathFactory,
    domain_path: Path,
):
    graph_trainer = GraphTrainer(
        model_storage=default_model_storage,
        cache=temp_cache,
        graph_runner_class=DaskGraphRunner,
    )

    test_value = "test_value"

    train_schema = GraphSchema({
        "train": SchemaNode(
            needs={},
            uses=PersistableTestComponent,
            fn="train",
            constructor_name="create",
            config={"test_value": test_value},
            is_target=True,
        ),
        "load": SchemaNode(
            needs={"resource": "train"},
            uses=PersistableTestComponent,
            fn="run_inference",
            constructor_name="load",
            config={},
        ),
    })
    predict_schema = GraphSchema({
        "load": SchemaNode(
            needs={},
            uses=PersistableTestComponent,
            fn="run_inference",
            constructor_name="load",
            config={},
            is_target=True,
            resource=Resource("train"),
        ),
    })

    output_filename = tmp_path / "model.tar.gz"

    trained_at = datetime.utcnow()
    with freezegun.freeze_time(trained_at):
        predict_graph_runner = graph_trainer.train(
            train_schema=train_schema,
            predict_schema=predict_schema,
            domain_path=domain_path,
            output_filename=output_filename,
        )

    assert isinstance(predict_graph_runner, DaskGraphRunner)
    assert output_filename.is_file()
    assert predict_graph_runner.run() == {"load": test_value}

    loaded_model_storage_path = tmp_path_factory.mktemp("loaded model storage")

    model_metadata, loaded_predict_graph_runner = loader.load_predict_graph_runner(
        storage_path=loaded_model_storage_path,
        model_archive_path=output_filename,
        model_storage_class=LocalModelStorage,
        graph_runner_class=DaskGraphRunner,
    )

    assert loaded_predict_graph_runner.run() == {"load": test_value}

    assert model_metadata.predict_schema == predict_schema
    assert model_metadata.train_schema == train_schema
    assert model_metadata.model_id
    assert model_metadata.domain.as_dict() == Domain.from_path(
        domain_path).as_dict()
    assert model_metadata.rasa_open_source_version == rasa.__version__
    assert model_metadata.trained_at == trained_at
Example #20
def temp_dir(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("test-tmp")
Example #21
def _trained_e2e_model_cache(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("cache")
Example #22
def tmp_root_homes_dir_enc_dec_ro(tmp_path_factory: TempPathFactory) -> Path:
    return tmp_path_factory.mktemp("root-homes-enc-dec-ro")