Example no. 1
def test_resource_caching(tmp_path_factory: TempPathFactory):
    model_storage = LocalModelStorage(
        tmp_path_factory.mktemp("initial_model_storage"))

    resource = Resource("my resource")

    # Fill model storage
    test_filename = "file.txt"
    test_content = "test_resource_caching"
    with model_storage.write_to(resource) as temporary_directory:
        file = temporary_directory / test_filename
        file.write_text(test_content)

    cache_dir = tmp_path_factory.mktemp("cache_dir")

    # Cache resource
    resource.to_cache(cache_dir, model_storage)

    # Reload resource from cache and inspect
    new_model_storage = LocalModelStorage(
        tmp_path_factory.mktemp("new_model_storage"))
    reinstantiated_resource = Resource.from_cache(resource.name, cache_dir,
                                                  new_model_storage)

    assert reinstantiated_resource == resource

    # Read written resource data from model storage to see whether all expected
    # contents are there
    with new_model_storage.read_from(resource) as temporary_directory:
        assert (temporary_directory /
                test_filename).read_text() == test_content
Example no. 2
def test_cached_component_replace_schema_node():
    schema_node = SchemaNode(
        needs={
            "i1": "first_input",
            "i2": "second_input"
        },
        uses=FingerprintComponent,
        fn="add",
        constructor_name="load",
        config={"a": 1},
        eager=False,
        is_input=False,
        resource=Resource("hello"),
    )

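    # Replace the node in place so it serves the pre-computed value 2 instead of running FingerprintComponent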
    PrecomputedValueProvider.replace_schema_node(schema_node, 2)

    assert schema_node == SchemaNode(
        needs={
            "i1": "first_input",
            "i2": "second_input"
        },
        uses=PrecomputedValueProvider,
        fn="get_value",
        constructor_name="create",
        config={"output": 2},
        eager=False,
        is_input=False,
        resource=Resource("hello"),
    )
Example no. 3
def test_resource_caching_if_already_restored(
        tmp_path_factory: TempPathFactory):
    initial_storage_dir = tmp_path_factory.mktemp("initial_model_storage")
    model_storage = LocalModelStorage(initial_storage_dir)

    resource = Resource("my resource")

    # Fill model storage
    test_filename = "file.txt"
    test_content = "test_resource_caching"
    with model_storage.write_to(resource) as temporary_directory:
        file = temporary_directory / test_filename
        file.write_text(test_content)

    cache_dir = tmp_path_factory.mktemp("cache_dir")

    # Cache resource
    resource.to_cache(cache_dir, model_storage)

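    # Copy the initial storage so the resource is already present in the new model storage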
    new_storage_dir = tmp_path_factory.mktemp("new dir")
    rasa.utils.common.copy_directory(initial_storage_dir, new_storage_dir)

    reinstantiated_resource = Resource.from_cache(
        resource.name,
        cache_dir,
        LocalModelStorage(new_storage_dir),
        resource.output_fingerprint,
    )

    assert reinstantiated_resource == resource
Example no. 4
def test_persist_and_load(
    training_data: TrainingData,
    config: Dict[Text, Any],
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
):
    classifier = KeywordIntentClassifierGraphComponent.create(
        config,
        default_model_storage,
        Resource("keyword"),
        default_execution_context,
    )
    classifier.train(training_data)

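    # Load the persisted classifier from the same model storage and resource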
    loaded_classifier = KeywordIntentClassifierGraphComponent.load(
        config,
        default_model_storage,
        Resource("keyword"),
        default_execution_context,
    )

    predicted = copy.copy(training_data)
    actual = copy.copy(training_data)
    loaded_messages = loaded_classifier.process(predicted.training_examples)
    trained_messages = classifier.process(actual.training_examples)
    for m1, m2 in zip(loaded_messages, trained_messages):
        assert m1.get("intent") == m2.get("intent")
Example no. 5
def test_read_from_not_existing_resource(default_model_storage: ModelStorage):
    with default_model_storage.write_to(
            Resource("resource1")) as temporary_directory:
        file = temporary_directory / "file.txt"
        file.write_text("test")

    with pytest.raises(ValueError):
        with default_model_storage.read_from(
                Resource("a different resource")) as _:
            pass
Example no. 6
def test_train_load_predict_loop(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    mitie_model: MitieModel,
    mitie_tokenizer: MitieTokenizer,
):
    resource = Resource("mitie_classifier")
    component = MitieIntentClassifier.create(
        MitieIntentClassifier.get_default_config(),
        default_model_storage,
        resource,
        default_execution_context,
    )

    training_data = rasa.shared.nlu.training_data.loading.load_data(
        "data/examples/rasa/demo-rasa.yml")
    # Tokenize message as classifier needs that
    mitie_tokenizer.process_training_data(training_data)

    component.train(training_data, mitie_model)

    component = MitieIntentClassifier.load(
        MitieIntentClassifier.get_default_config(),
        default_model_storage,
        resource,
        default_execution_context,
    )

    test_message = Message({TEXT: "hi"})
    mitie_tokenizer.process([test_message])
    component.process([test_message], mitie_model)

    assert test_message.data[INTENT][INTENT_NAME_KEY] == "greet"
    assert test_message.data[INTENT][PREDICTED_CONFIDENCE_KEY] > 0
Example no. 7
def test_train_with_e2e_data(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    tracker_events: List[List[Event]],
    skip_training: bool,
    domain: Domain,
):
    policy = UnexpecTEDIntentPolicy(
        UnexpecTEDIntentPolicy.get_default_config(),
        default_model_storage,
        Resource("UnexpecTEDIntentPolicy"),
        default_execution_context,
        featurizer=IntentMaxHistoryTrackerFeaturizer(
            IntentTokenizerSingleStateFeaturizer()
        ),
    )
    trackers_for_training = [
        TrackerWithCachedStates.from_events(
            sender_id=f"{tracker_index}", evts=events, domain=domain
        )
        for tracker_index, events in enumerate(tracker_events)
    ]
    if skip_training:
        with pytest.warns(UserWarning):
            policy.train(trackers_for_training, domain, precomputations=None)
    else:
        policy.train(trackers_for_training, domain, precomputations=None)
Example no. 8
def test_diagnostics(default_model_storage: ModelStorage,
                     default_execution_context: ExecutionContext):
    domain = Domain.from_yaml(DOMAIN_YAML)
    policy = TEDPolicy(
        TEDPolicy.get_default_config(),
        default_model_storage,
        Resource("TEDPolicy"),
        default_execution_context,
    )
    GREET_RULE = DialogueStateTracker.from_events(
        "greet rule",
        evts=[
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(UTTER_GREET_ACTION),
            ActionExecuted(ACTION_LISTEN_NAME),
            UserUttered(intent={"name": GREET_INTENT_NAME}),
            ActionExecuted(ACTION_LISTEN_NAME),
        ],
    )
    precomputations = None
    policy.train([GREET_RULE], domain, precomputations)
    prediction = policy.predict_action_probabilities(
        GREET_RULE,
        domain,
        precomputations,
    )

    assert prediction.diagnostic_data
    assert "attention_weights" in prediction.diagnostic_data
    assert isinstance(prediction.diagnostic_data.get("attention_weights"),
                      np.ndarray)
Example no. 9
def test_domain_provider_provides_and_persists_domain(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    config_path: Text,
    domain_path: Text,
    domain: Domain,
):
    resource = Resource("xy")
    component = DomainProvider.create(
        DomainProvider.get_default_config(),
        default_model_storage,
        resource,
        default_execution_context,
    )
    assert isinstance(component, DomainProvider)

    importer = TrainingDataImporter.load_from_config(config_path, domain_path)
    training_domain = component.provide_train(importer)

    assert isinstance(training_domain, Domain)
    assert domain.fingerprint() == training_domain.fingerprint()

    with default_model_storage.read_from(resource) as d:
        match = list(d.glob("**/domain.yml"))
        assert len(match) == 1
        assert match[0].is_file()
        assert domain.fingerprint() == Domain.from_path(match[0]).fingerprint()

    component_2 = DomainProvider.load(
        {}, default_model_storage, resource, default_execution_context
    )
    inference_domain = component_2.provide_inference()

    assert isinstance(inference_domain, Domain)
    assert domain.fingerprint() == inference_domain.fingerprint()
Example no. 10
def test_story_graph_provider_provide(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    config: Dict[Text, Any],
    config_path: Text,
    domain_path: Text,
    stories_path: Text,
):
    component = StoryGraphProvider.create(
        {
            **StoryGraphProvider.get_default_config(),
            **config
        },
        default_model_storage,
        Resource("xy"),
        default_execution_context,
    )
    importer = TrainingDataImporter.load_from_config(config_path, domain_path,
                                                     [stories_path])

    story_graph_from_component = component.provide(importer)
    assert isinstance(story_graph_from_component, StoryGraph)

    story_graph = importer.get_stories(**config)

    assert story_graph.fingerprint() == story_graph_from_component.fingerprint()
Example no. 11
def test_synonym_mapper_with_ints(default_model_storage: ModelStorage,
                                  default_execution_context: ExecutionContext):
    resource = Resource("xy")
    mapper = EntitySynonymMapperComponent.create({}, default_model_storage,
                                                 resource,
                                                 default_execution_context)
    entities = [{
        "start": 21,
        "end": 22,
        "text": "5",
        "value": 5,
        "confidence": 1.0,
        "additional_info": {
            "value": 5,
            "type": "value"
        },
        "entity": "number",
        "extractor": "DucklingEntityExtractorComponent",
    }]
    message = Message(data={TEXT: "He was 6 feet away", ENTITIES: entities})

    # This doesn't break
    mapper.process([message])

    assert message.get(ENTITIES) == entities
Example no. 12
def test_provide_multiproject_importer(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
) -> None:
    config_path = "data/test_multiproject/config.yml"
    importer_config = {"importers": [{"name": "MultiProjectImporter"}]}
    config = {
        "config_path": config_path,
        "config": importer_config,
        "domain_path": None,
        "training_data_paths": None,
    }
    project_provider = ProjectProvider.create(config, default_model_storage,
                                              Resource("xy"),
                                              default_execution_context)
    importer = project_provider.provide()

    training_data = importer.get_nlu_data()
    assert len(training_data.intents) == 4

    domain = importer.get_domain()
    assert len(domain.responses) == 4

    project_config = importer.get_config()
    assert len(project_config["policies"]) == 3
Example no. 13
def test_persist_and_load(
    training_data: TrainingData,
    default_sklearn_intent_classifier: SklearnIntentClassifier,
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]],
    spacy_nlp_component: SpacyNLP,
    spacy_model: SpacyModel,
):
    training_data = spacy_nlp_component.process_training_data(
        training_data, spacy_model
    )

    training_data, loaded_pipeline = train_and_preprocess(
        pipeline=[{"component": SpacyTokenizer}, {"component": SpacyFeaturizer}],
        training_data=training_data,
    )
    default_sklearn_intent_classifier.train(training_data)

    loaded = SklearnIntentClassifier.load(
        SklearnIntentClassifier.get_default_config(),
        default_model_storage,
        Resource("sklearn"),
        default_execution_context,
    )

    predicted = copy.deepcopy(training_data)
    actual = copy.deepcopy(training_data)
    loaded_messages = loaded.process(predicted.training_examples)
    trained_messages = default_sklearn_intent_classifier.process(
        actual.training_examples
    )

    for m1, m2 in zip(loaded_messages, trained_messages):
        assert m1.get("intent") == m2.get("intent")
Example no. 14
def test_prediction_adder_add_message(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    moodbot_domain: Domain,
    messages: List[Message],
    expected: List[UserUttered],
    input_channel: Text,
):
    component = NLUPredictionToHistoryAdder.create(
        {**NLUPredictionToHistoryAdder.get_default_config()},
        default_model_storage,
        Resource("test"),
        default_execution_context,
    )

    tracker = DialogueStateTracker("test", None)
    original_message = UserMessage(
        text="hello", input_channel=input_channel, metadata={"meta": "meta"}
    )
    tracker = component.add(messages, tracker, moodbot_domain, original_message)

    assert len(tracker.events) == len(messages)
    for i, _ in enumerate(messages):
        assert isinstance(tracker.events[i], UserUttered)
        assert tracker.events[i].text == expected[i].text
        assert tracker.events[i].intent == expected[i].intent
        assert tracker.events[i].entities == expected[i].entities
        assert tracker.events[i].input_channel == expected[i].input_channel
        assert tracker.events[i].message_id == expected[i].message_id
        assert tracker.events[i].metadata == expected[i].metadata
        assert tracker.events[i] == expected[i]
Example no. 15
async def test_train_persist_load_with_composite_entities(
    crf_entity_extractor: Callable[[Dict[Text, Any]],
                                   CRFEntityExtractorGraphComponent],
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    whitespace_tokenizer: WhitespaceTokenizerGraphComponent,
):
    importer = RasaFileImporter(
        training_data_paths=["data/test/demo-rasa-composite-entities.yml"])
    training_data = importer.get_nlu_data()

    whitespace_tokenizer.process_training_data(training_data)

    crf_extractor = crf_entity_extractor({})
    crf_extractor.train(training_data)

    message = Message(data={TEXT: "I am looking for an italian restaurant"})

    whitespace_tokenizer.process([message])
    message2 = copy.deepcopy(message)

    processed_message = crf_extractor.process([message])[0]

    loaded_extractor = CRFEntityExtractorGraphComponent.load(
        CRFEntityExtractorGraphComponent.get_default_config(),
        default_model_storage,
        Resource("CRFEntityExtractor"),
        default_execution_context,
    )

    processed_message2 = loaded_extractor.process([message2])[0]

    assert processed_message2.fingerprint() == processed_message.fingerprint()
Example no. 16
def test_loading_from_resource_eager(default_model_storage: ModelStorage):
    previous_resource = Resource("previous resource")
    test_value = {"test": "test value"}

    # Pretend resource persisted itself before
    with default_model_storage.write_to(previous_resource) as directory:
        rasa.shared.utils.io.dump_obj_as_json_to_file(directory / "test.json",
                                                      test_value)

    node_name = "some_name"
    node = GraphNode(
        node_name=node_name,
        component_class=PersistableTestComponent,
        constructor_name="load",
        component_config={},
        fn_name="run_inference",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        # The `GraphComponent` should load from this resource
        resource=previous_resource,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

    actual_node_name, value = node()

    assert actual_node_name == node_name
    assert value == test_value
Example no. 17
def test_writing_to_resource_during_training(
        default_model_storage: ModelStorage):
    node_name = "some_name"

    test_value_for_sub_directory = {"test": "test value sub dir"}
    test_value = {"test dir": "test value dir"}

    node = GraphNode(
        node_name=node_name,
        component_class=PersistableTestComponent,
        constructor_name="create",
        component_config={
            "test_value": test_value,
            "test_value_for_sub_directory": test_value_for_sub_directory,
        },
        fn_name="train",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

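    # Running the node during training creates a resource named after the node and persists to it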
    _, resource = node()

    assert resource == Resource(node_name)

    with default_model_storage.read_from(resource) as directory:
        assert (rasa.shared.utils.io.read_json_file(directory /
                                                    "test.json") == test_value)
        assert (rasa.shared.utils.io.read_json_file(
            directory / "sub_dir" /
            "test.json") == test_value_for_sub_directory)
Example no. 18
def test_train_persist_nlu_data(run_in_simple_project: Callable[...,
                                                                RunResult],
                                tmp_path: Path):
    temp_dir = os.getcwd()

    run_in_simple_project(
        "train",
        "-c",
        "config.yml",
        "-d",
        "domain.yml",
        "--data",
        "data",
        "--out",
        "train_models",
        "--fixed-model-name",
        "test-model",
        "--persist-nlu-data",
    )

    models_dir = Path(temp_dir, "train_models")
    assert models_dir.is_dir()

    models = list(models_dir.glob("*"))
    assert len(models) == 1

    model = models[0]
    assert model.name == "test-model.tar.gz"

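    # Unpack the trained model archive into a fresh model storage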
    storage, _ = LocalModelStorage.from_model_archive(tmp_path, model)

    with storage.read_from(
            Resource("nlu_training_data_provider")) as directory:
        assert (directory / DEFAULT_TRAINING_DATA_OUTPUT_PATH).exists()
Example no. 19
def test_serialize_graph_schema(tmp_path: Path):
    graph_schema = GraphSchema(
        {
            "train": SchemaNode(
                needs={},
                uses=PersistableTestComponent,
                fn="train",
                constructor_name="create",
                config={"some_config": 123455, "some more config": [{"nested": "hi"}]},
            ),
            "load": SchemaNode(
                needs={"resource": "train"},
                uses=PersistableTestComponent,
                fn="run_inference",
                constructor_name="load",
                config={},
                is_target=True,
                resource=Resource("test resource"),
            ),
        }
    )

    serialized = graph_schema.as_dict()

    # Dump it to make sure it's actually serializable
    file_path = tmp_path / "my_graph.yml"
    rasa.shared.utils.io.write_yaml(serialized, file_path)

    serialized_graph_schema_from_file = rasa.shared.utils.io.read_yaml_file(file_path)
    graph_schema_from_file = GraphSchema.from_dict(serialized_graph_schema_from_file)

    assert graph_schema_from_file == graph_schema
Example no. 20
def test_train_nlu_persist_nlu_data(run_in_simple_project: Callable[...,
                                                                    RunResult],
                                    tmp_path: Path) -> None:
    run_in_simple_project(
        "train",
        "nlu",
        "-c",
        "config.yml",
        "--nlu",
        "data/nlu.yml",
        "--out",
        "train_models",
        "--persist-nlu-data",
    )

    models_dir = Path("train_models")
    assert models_dir.is_dir()

    models = list(models_dir.glob("*"))
    assert len(models) == 1

    model = models[0]
    assert model.name.startswith("nlu-")

    storage, _ = LocalModelStorage.from_model_archive(tmp_path, model)

    with storage.read_from(
            Resource("nlu_training_data_provider")) as directory:
        assert (directory / DEFAULT_TRAINING_DATA_OUTPUT_PATH).exists()
Example no. 21
def test_train_core_with_original_or_provided_domain_and_compare(
    tmp_path_factory: TempPathFactory,
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
):
    # Choose an example where the provider will remove a lot of information:
    example = Path("examples/formbot/")
    training_files = [example / "data" / "rules.yml"]

    # Choose a configuration with a policy
    # Note: This is sufficient to illustrate that the component won't be re-trained
    # when the domain changes. We do *not* test here whether removing keys would/
    # should not have any effect.
    config = """
    recipe: default.v1
    language: en

    policies:
      - name: RulePolicy
    """
    config_dir = tmp_path_factory.mktemp("config dir")
    config_file = config_dir / "config.yml"
    with open(config_file, "w") as f:
        f.write(config)

    # Train with the original domain
    original_domain_file = example / "domain.yml"
    original_output_dir = tmp_path_factory.mktemp("output dir")
    model_training.train(
        domain=original_domain_file,
        config=str(config_file),
        training_files=training_files,
        output=original_output_dir,
    )

    # Let the provider create a modified domain
    original_domain = Domain.from_file(original_domain_file)
    component = DomainForCoreTrainingProvider.create(
        {"arbitrary-unused": 234},
        default_model_storage,
        Resource("xy"),
        default_execution_context,
    )
    modified_domain = component.provide(domain=original_domain)

    # Dry-run training with the modified domain
    modified_domain_dir = tmp_path_factory.mktemp("modified domain dir")
    modified_domain_file = modified_domain_dir / "modified_config.yml"
    modified_domain.persist(modified_domain_file)

    modified_output_dir = tmp_path_factory.mktemp("modified output dir")
    modified_result = model_training.train(
        domain=modified_domain_file,
        config=str(config_file),
        training_files=training_files,
        output=modified_output_dir,
        dry_run=True,
    )

    assert modified_result.dry_run_results["train_RulePolicy0"].is_hit
Example no. 22
def test_provide(default_model_storage: ModelStorage,
                 default_execution_context: ExecutionContext):
    resource = Resource("some resource")

    domain = Domain.load("examples/rules/domain.yml")
    trackers = rasa.core.training.load_data("examples/rules/data/rules.yml",
                                            domain)

    policy = RulePolicy.create(
        RulePolicy.get_default_config(),
        default_model_storage,
        resource,
        default_execution_context,
    )

    policy.train(trackers, domain)

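    # The provider exposes the rule-only slots and loops persisted by RulePolicy during training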
    provider = RuleOnlyDataProvider.load({}, default_model_storage, resource,
                                         default_execution_context)
    rule_only_data = provider.provide()

    assert rule_only_data

    for key in [RULE_ONLY_SLOTS, RULE_ONLY_LOOPS]:
        assert rule_only_data[key] == policy.lookup[key]
Example no. 23
    def from_dict(cls, serialized_graph_schema: Dict[Text,
                                                     Any]) -> GraphSchema:
        """Loads a graph schema which has been serialized using `schema.as_dict()`.

        Args:
            serialized_graph_schema: A serialized graph schema.

        Returns:
            A properly loaded schema.

        Raises:
            GraphSchemaException: In case the component class for a node couldn't be
                found.
        """
        nodes = {}
        for node_name, serialized_node in serialized_graph_schema[
                "nodes"].items():
            try:
                serialized_node[
                    "uses"] = rasa.shared.utils.common.class_from_module_path(
                        serialized_node["uses"])

                resource = serialized_node["resource"]
                if resource:
                    serialized_node["resource"] = Resource(**resource)

            except ImportError as e:
                raise GraphSchemaException(
                    "Error deserializing graph schema. Can't "
                    "find class for graph component type "
                    f"'{serialized_node['uses']}'.") from e

            nodes[node_name] = SchemaNode(**serialized_node)

        return GraphSchema(nodes)
Example no. 24
def test_entity_synonyms(default_model_storage: ModelStorage,
                         default_execution_context: ExecutionContext):
    resource = Resource("xy")
    entities = [
        {
            "entity": "test",
            "value": "chines",
            "start": 0,
            "end": 6
        },
        {
            "entity": "test",
            "value": "chinese",
            "start": 0,
            "end": 6
        },
        {
            "entity": "test",
            "value": "china",
            "start": 0,
            "end": 6
        },
    ]
    ent_synonyms = {"chines": "chinese", "NYC": "New York City"}

    mapper = EntitySynonymMapper.create({}, default_model_storage, resource,
                                        default_execution_context,
                                        ent_synonyms)
    mapper.replace_synonyms(entities)

    assert len(entities) == 3
    assert entities[0]["value"] == "chinese"
    assert entities[1]["value"] == "chinese"
    assert entities[2]["value"] == "china"
Example no. 25
def test_resource_with_model_storage(default_model_storage: ModelStorage,
                                     tmp_path: Path,
                                     temp_cache: TrainingCache):
    node_name = "some node"
    resource = Resource(node_name)
    test_filename = "persisted_model.json"
    test_content = {"epochs": 500}

    with default_model_storage.write_to(resource) as temporary_directory:
        rasa.shared.utils.io.dump_obj_as_json_to_file(
            temporary_directory / test_filename, test_content)

    test_fingerprint_key = uuid.uuid4().hex
    test_output_fingerprint_key = uuid.uuid4().hex
    temp_cache.cache_output(
        test_fingerprint_key,
        resource,
        test_output_fingerprint_key,
        default_model_storage,
    )

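    # Restoring from the cache into an empty model storage should recreate the resource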
    new_model_storage_location = tmp_path / "new_model_storage"
    new_model_storage_location.mkdir()
    new_model_storage = LocalModelStorage(new_model_storage_location)
    restored_resource = temp_cache.get_cached_result(
        test_output_fingerprint_key, node_name, new_model_storage)

    assert isinstance(restored_resource, Resource)
    assert restored_resource == resource

    with new_model_storage.read_from(restored_resource) as temporary_directory:
        cached_content = rasa.shared.utils.io.read_json_file(
            temporary_directory / test_filename)
        assert cached_content == test_content
Example no. 26
def test_write_to_and_read(default_model_storage: ModelStorage):
    test_filename = "file.txt"
    test_file_content = "hi"

    test_sub_filename = "sub_file"
    test_sub_dir_name = "sub_directory"
    test_sub_file_content = "sub file"

    resource = Resource("some_node123")

    # Fill model storage for resource
    with default_model_storage.write_to(resource) as resource_directory:
        file = resource_directory / test_filename
        file.write_text(test_file_content)

        sub_directory = resource_directory / test_sub_dir_name
        sub_directory.mkdir()
        file_in_sub_directory = sub_directory / test_sub_filename
        file_in_sub_directory.write_text(test_sub_file_content)

    # Read written resource data from model storage to see whether all expected
    # content is there
    with default_model_storage.read_from(resource) as resource_directory:
        assert (resource_directory /
                test_filename).read_text() == test_file_content
        assert (resource_directory / test_sub_dir_name /
                test_sub_filename).read_text() == test_sub_file_content
Example no. 27
def input_converter(default_model_storage: ModelStorage,
                    default_execution_context: ExecutionContext):
    return CoreFeaturizationInputConverter.create(
        CoreFeaturizationInputConverter.get_default_config(),
        default_model_storage,
        Resource("CoreFeaturizationInputConverters"),
        default_execution_context,
    )
Example no. 28
def inner(config: Optional[Dict[Text, Any]] = None) -> CountVectorsFeaturizer:
    config = config or {}
    return CountVectorsFeaturizer.create(
        {**CountVectorsFeaturizer.get_default_config(), **config},
        default_model_storage,
        Resource("count_vectors_featurizer"),
        default_execution_context,
    )
Example no. 29
def collector(default_model_storage: ModelStorage,
              default_execution_context: ExecutionContext):
    return CoreFeaturizationCollector.create(
        CoreFeaturizationCollector.get_default_config(),
        default_model_storage,
        Resource("CoreFeaturizationCollector"),
        default_execution_context,
    )
Example no. 30
def test_graph_trainer_returns_prediction_runner(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
    tmp_path: Path,
    domain_path: Path,
):
    graph_trainer = GraphTrainer(
        model_storage=default_model_storage,
        cache=temp_cache,
        graph_runner_class=DaskGraphRunner,
    )

    test_value = "test_value"

    train_schema = GraphSchema({
        "train":
        SchemaNode(
            needs={},
            uses=PersistableTestComponent,
            fn="train",
            constructor_name="create",
            config={
                "test_value": test_value,
            },
            is_target=True,
        ),
        "load":
        SchemaNode(
            needs={"resource": "train"},
            uses=PersistableTestComponent,
            fn="run_inference",
            constructor_name="load",
            config={},
        ),
    })
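    # The predict graph loads the component from the resource persisted by the "train" node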
    predict_schema = GraphSchema({
        "load":
        SchemaNode(
            needs={},
            uses=PersistableTestComponent,
            fn="run_inference",
            constructor_name="load",
            config={},
            is_target=True,
            resource=Resource("train"),
        ),
    })

    output_filename = tmp_path / "model.tar.gz"
    predict_graph_runner = graph_trainer.train(
        train_schema=train_schema,
        predict_schema=predict_schema,
        domain_path=domain_path,
        output_filename=output_filename,
    )
    assert isinstance(predict_graph_runner, DaskGraphRunner)
    assert output_filename.is_file()
    assert predict_graph_runner.run() == {"load": test_value}