コード例 #1
0
def test_read_unsupported_model(
    monkeypatch: MonkeyPatch,
    tmp_path_factory: TempPathFactory,
    domain: Domain,
) -> None:
    """Check that an archive carrying an outdated version cannot be loaded.

    A model package is created whose metadata reports Rasa Open Source
    version ``0.0.1``. Both ``LocalModelStorage.metadata_from_archive`` and
    ``LocalModelStorage.from_model_archive`` are then expected to raise
    ``UnsupportedModelVersionError`` with a message naming that version.

    Args:
        monkeypatch: pytest fixture used to replace
            ``LocalModelStorage._create_model_metadata``.
        tmp_path_factory: pytest fixture providing the temporary directories.
        domain: Domain stored in the metadata and passed to the packaging call.
    """
    # Local import: only needed to escape the version inside the regex below.
    import re

    train_model_storage = LocalModelStorage(
        tmp_path_factory.mktemp("train model storage")
    )
    graph_schema = GraphSchema(nodes={})

    persisted_model_dir = tmp_path_factory.mktemp("persisted models")
    archive_path = persisted_model_dir / "my-model.tar.gz"

    # Create outdated model meta data
    trained_at = datetime.utcnow()
    model_configuration = GraphModelConfiguration(
        graph_schema, graph_schema, TrainingType.BOTH, None, None, "nlu"
    )
    outdated_model_meta_data = ModelMetadata(
        trained_at=trained_at,
        # Built with the current version first; the outdated version is
        # written onto the instance afterwards to avoid an error here.
        rasa_open_source_version=rasa.__version__,
        model_id=uuid.uuid4().hex,
        domain=domain,
        train_schema=model_configuration.train_schema,
        predict_schema=model_configuration.predict_schema,
        training_type=model_configuration.training_type,
        project_fingerprint=rasa.model.project_fingerprint(),
        language=model_configuration.language,
        core_target=model_configuration.core_target,
        nlu_target=model_configuration.nlu_target,
    )
    old_version = "0.0.1"
    outdated_model_meta_data.rasa_open_source_version = old_version

    # Package model - and inject the outdated model meta data
    monkeypatch.setattr(
        LocalModelStorage,
        "_create_model_metadata",
        lambda *args, **kwargs: outdated_model_meta_data,
    )
    train_model_storage.create_model_package(
        model_archive_path=archive_path,
        model_configuration=model_configuration,
        domain=domain,
    )

    # Unpack and inspect packaged model
    load_model_storage_dir = tmp_path_factory.mktemp("load model storage")

    # ``match`` is applied as a regular expression, so the version is escaped
    # to make its dots match literally; the trailing ``.*`` stays a wildcard.
    expected_message = (
        "The model version is trained using Rasa Open Source "
        f"{re.escape(old_version)} and is not compatible with your current "
        "installation .*"
    )
    with pytest.raises(UnsupportedModelVersionError, match=expected_message):
        LocalModelStorage.metadata_from_archive(archive_path)

    with pytest.raises(UnsupportedModelVersionError, match=expected_message):
        LocalModelStorage.from_model_archive(load_model_storage_dir, archive_path)
コード例 #2
0
    def _persist_metadata(
        metadata: ModelMetadata,
        temporary_directory: Path,
    ) -> None:
        """Dump ``metadata`` as JSON into the metadata file of a directory.

        Args:
            metadata: Metadata whose ``as_dict()`` form is written.
            temporary_directory: Directory that receives the file named by
                ``MODEL_ARCHIVE_METADATA_FILE``.
        """
        metadata_file = temporary_directory / MODEL_ARCHIVE_METADATA_FILE
        serialized_metadata = metadata.as_dict()
        rasa.shared.utils.io.dump_obj_as_json_to_file(
            metadata_file, serialized_metadata
        )
コード例 #3
0
 def _create_model_metadata(
     domain: Domain,
     predict_schema: GraphSchema,
     train_schema: GraphSchema,
 ) -> ModelMetadata:
     """Build a ``ModelMetadata`` for the given domain and graph schemas.

     The metadata is stamped with the current UTC time, the installed Rasa
     version and a newly generated UUID4 hex string as model id.
     """
     created_at = datetime.utcnow()
     installed_version = rasa.__version__
     new_model_id = uuid.uuid4().hex
     return ModelMetadata(
         trained_at=created_at,
         rasa_open_source_version=installed_version,
         model_id=new_model_id,
         domain=domain,
         train_schema=train_schema,
         predict_schema=predict_schema,
     )
コード例 #4
0
 def _create_model_metadata(
     domain: Domain, model_configuration: GraphModelConfiguration
 ) -> ModelMetadata:
     """Build a ``ModelMetadata`` from a domain and a model configuration.

     Schemas, training type, language and the core/NLU targets are copied
     from ``model_configuration``. The current UTC time, the installed Rasa
     version, a newly generated UUID4 hex id and the project fingerprint
     are added on top.
     """
     created_at = datetime.utcnow()
     installed_version = rasa.__version__
     new_model_id = uuid.uuid4().hex
     config = model_configuration
     return ModelMetadata(
         trained_at=created_at,
         rasa_open_source_version=installed_version,
         model_id=new_model_id,
         domain=domain,
         train_schema=config.train_schema,
         predict_schema=config.predict_schema,
         training_type=config.training_type,
         project_fingerprint=rasa.model.project_fingerprint(),
         language=config.language,
         core_target=config.core_target,
         nlu_target=config.nlu_target,
     )
コード例 #5
0
ファイル: test_storage.py プロジェクト: zoovu/rasa
def test_metadata_version_check() -> None:
    """Check that building ``ModelMetadata`` with an old version is rejected.

    Constructing the metadata with Rasa Open Source version ``2.7.2`` is
    expected to raise ``UnsupportedModelVersionError`` with a message that
    names that version.
    """
    # Local import: only needed to escape the version inside the regex below.
    import re

    trained_at = datetime.utcnow()
    old_version = "2.7.2"
    # ``match`` is applied as a regular expression, so the version is escaped
    # to make its dots match literally; the trailing ``.*`` stays a wildcard.
    expected_message = (
        "The model version is trained using Rasa Open Source "
        f"{re.escape(old_version)} and is not compatible with your current "
        "installation .*"
    )
    with pytest.raises(UnsupportedModelVersionError, match=expected_message):
        ModelMetadata(
            trained_at,
            old_version,
            "some id",
            Domain.empty(),
            GraphSchema(nodes={}),
            GraphSchema(nodes={}),
            project_fingerprint="some_fingerprint",
            training_type=TrainingType.NLU,
            core_target="core",
            nlu_target="nlu",
            language="zh",
        )
コード例 #6
0
    def _load_metadata(directory: Path) -> ModelMetadata:
        """Read the metadata JSON file in ``directory`` into a ``ModelMetadata``.

        The file is the one named by ``MODEL_ARCHIVE_METADATA_FILE``; its
        parsed content is handed to ``ModelMetadata.from_dict``.
        """
        metadata_file = directory / MODEL_ARCHIVE_METADATA_FILE
        raw_metadata = rasa.shared.utils.io.read_json_file(metadata_file)
        return ModelMetadata.from_dict(raw_metadata)
コード例 #7
0
ファイル: test_storage.py プロジェクト: zoovu/rasa
def test_metadata_serialization(domain: Domain, tmp_path: Path):
    """Round-trip a ``ModelMetadata`` through a JSON file and compare fields.

    The metadata is converted with ``as_dict``, written to disk, read back
    and rebuilt with ``from_dict``; every field of the rebuilt object must
    equal the value the original was constructed with.
    """
    # Training graph: a "train" node and a target "load" node that needs it.
    train_node = SchemaNode(
        needs={},
        uses=PersistableTestComponent,
        fn="train",
        constructor_name="create",
        config={"some_config": 123455, "some more config": [{"nested": "hi"}]},
    )
    load_node = SchemaNode(
        needs={"resource": "train"},
        uses=PersistableTestComponent,
        fn="run_inference",
        constructor_name="load",
        config={},
        is_target=True,
    )
    expected_train_schema = GraphSchema({"train": train_node, "load": load_node})

    # Prediction graph: a single "run" node.
    run_node = SchemaNode(
        needs={},
        uses=PersistableTestComponent,
        fn="run",
        constructor_name="load",
        config={"some_config": 123455, "some more config": [{"nested": "hi"}]},
    )
    expected_predict_schema = GraphSchema({"run": run_node})

    expected_trained_at = datetime.utcnow()
    expected_version = rasa.__version__
    expected_model_id = "some unique model id"
    original_metadata = ModelMetadata(
        expected_trained_at,
        expected_version,
        expected_model_id,
        domain,
        expected_train_schema,
        expected_predict_schema,
        project_fingerprint="some_fingerprint",
        training_type=TrainingType.NLU,
        core_target="core",
        nlu_target="nlu",
        language="zh",
    )

    as_dict = original_metadata.as_dict()

    # Going through an actual file proves the dict is JSON-serializable.
    json_file = tmp_path / "metadata.json"
    rasa.shared.utils.io.dump_obj_as_json_to_file(json_file, as_dict)
    read_back = rasa.shared.utils.io.read_json_file(json_file)

    restored = ModelMetadata.from_dict(read_back)

    assert restored.domain.as_dict() == domain.as_dict()
    assert restored.model_id == expected_model_id
    assert restored.rasa_open_source_version == expected_version
    assert restored.trained_at == expected_trained_at
    assert restored.train_schema == expected_train_schema
    assert restored.predict_schema == expected_predict_schema
    assert restored.project_fingerprint == "some_fingerprint"
    assert restored.training_type == TrainingType.NLU
    assert restored.core_target == "core"
    assert restored.nlu_target == "nlu"
    assert restored.language == "zh"