예제 #1
0
파일: processor.py 프로젝트: zoovu/rasa
    def _load_model(
        model_path: Union[Text, Path],
    ) -> Tuple[Text, ModelMetadata, GraphRunner]:
        """Unpacks a model from a given path using the graph model loader.

        Args:
            model_path: Either the path of a model archive file, or a
                location that is handed to `get_latest_model` to resolve
                the archive to load.

        Returns:
            A tuple of the archive's base file name and the metadata and
            graph runner returned by `loader.load_predict_graph_runner`.

        Raises:
            ModelNotFound: If no archive can be resolved from `model_path`,
                if resolving it raises a `TypeError`, or if the archive
                cannot be read as a tar file.
        """
        try:
            if os.path.isfile(model_path):
                model_tar = model_path
            else:
                model_tar = get_latest_model(model_path)
                if not model_tar:
                    raise ModelNotFound(
                        f"No model found at path '{model_path}'.")
        except TypeError as error:
            # Keep the original error as the explicit cause so the traceback
            # shows why the path could not be handled.
            raise ModelNotFound(
                f"Model {model_path} can not be loaded.") from error

        logger.info(f"Loading model {model_tar}...")
        # The temporary directory is deleted when the `with` block exits,
        # i.e. as this function returns.
        with tempfile.TemporaryDirectory() as temporary_directory:
            try:
                metadata, runner = loader.load_predict_graph_runner(
                    Path(temporary_directory),
                    Path(model_tar),
                    LocalModelStorage,
                    DaskGraphRunner,
                )
                return os.path.basename(model_tar), metadata, runner
            except tarfile.ReadError as error:
                # Chain the tar error as well; the message callers see is
                # unchanged.
                raise ModelNotFound(
                    f"Model {model_path} can not be loaded.") from error
예제 #2
0
파일: test_loader.py 프로젝트: spawn08/rasa
def test_loader_loads_graph_runner(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
    tmp_path: Path,
    tmp_path_factory: TempPathFactory,
    domain_path: Path,
):
    """Train a small graph model, load its archive, and verify the result.

    The test trains with a `GraphTrainer`, writes the model to an archive
    under `tmp_path`, loads that archive through
    `loader.load_predict_graph_runner` into a fresh storage directory, and
    then checks both the output of the loaded runner and the loaded metadata.
    """
    trainer = GraphTrainer(
        model_storage=default_model_storage,
        cache=temp_cache,
        graph_runner_class=DaskGraphRunner,
    )

    expected_value = "test_value"

    # Training graph: a target "train" node plus a "load" node that needs the
    # "train" node's output as its resource.
    train_node = SchemaNode(
        needs={},
        uses=PersistableTestComponent,
        fn="train",
        constructor_name="create",
        config={"test_value": expected_value},
        is_target=True,
    )
    load_after_train_node = SchemaNode(
        needs={"resource": "train"},
        uses=PersistableTestComponent,
        fn="run_inference",
        constructor_name="load",
        config={},
    )
    train_schema = GraphSchema({"train": train_node, "load": load_after_train_node})

    # Prediction graph: a single target "load" node bound to the "train"
    # resource.
    predict_load_node = SchemaNode(
        needs={},
        uses=PersistableTestComponent,
        fn="run_inference",
        constructor_name="load",
        config={},
        is_target=True,
        resource=Resource("train"),
    )
    predict_schema = GraphSchema({"load": predict_load_node})

    archive_path = tmp_path / "model.tar.gz"

    importer = TrainingDataImporter.load_from_dict(
        training_data_paths=[], domain_path=str(domain_path)
    )

    # Freeze the clock so the recorded training time can be compared exactly.
    trained_at = datetime.utcnow()
    with freezegun.freeze_time(trained_at):
        model_configuration = GraphModelConfiguration(
            train_schema=train_schema,
            predict_schema=predict_schema,
            training_type=TrainingType.BOTH,
            language=None,
            core_target=None,
            nlu_target=None,
        )
        trained_metadata = trainer.train(
            model_configuration,
            importer=importer,
            output_filename=archive_path,
        )

    assert isinstance(trained_metadata, ModelMetadata)
    assert archive_path.is_file()

    # Load the archive into a separate, freshly created storage directory.
    loaded_storage_dir = tmp_path_factory.mktemp("loaded model storage")
    loaded_metadata, loaded_runner = loader.load_predict_graph_runner(
        storage_path=loaded_storage_dir,
        model_archive_path=archive_path,
        model_storage_class=LocalModelStorage,
        graph_runner_class=DaskGraphRunner,
    )

    assert loaded_runner.run() == {"load": expected_value}

    # The loaded metadata must match what was used for training.
    assert loaded_metadata.predict_schema == predict_schema
    assert loaded_metadata.train_schema == train_schema
    assert loaded_metadata.model_id
    expected_domain = Domain.from_path(domain_path).as_dict()
    assert loaded_metadata.domain.as_dict() == expected_domain
    assert loaded_metadata.rasa_open_source_version == rasa.__version__
    assert loaded_metadata.trained_at == trained_at