Example #1
async def test_process_gives_diagnostic_data(
    create_train_load_and_process_diet: Callable[..., Message],
    default_execution_context: ExecutionContext,
    should_add_diagnostic_data: bool,
):
    default_execution_context.should_add_diagnostic_data = should_add_diagnostic_data
    default_execution_context.node_name = "DIETClassifier_node_name"
    processed_message = create_train_load_and_process_diet({EPOCHS: 1})

    if should_add_diagnostic_data:
        # Tests if processing a message returns attention weights as numpy array.
        diagnostic_data = processed_message.get(DIAGNOSTIC_DATA)

        # DIETClassifier should add attention weights
        name = "DIETClassifier_node_name"
        assert isinstance(diagnostic_data, dict)
        assert name in diagnostic_data
        assert "attention_weights" in diagnostic_data[name]
        assert isinstance(diagnostic_data[name].get("attention_weights"),
                          np.ndarray)
        assert "text_transformed" in diagnostic_data[name]
        assert isinstance(diagnostic_data[name].get("text_transformed"),
                          np.ndarray)
    else:
        assert DIAGNOSTIC_DATA not in processed_message.data
Example #2
    def train(
        self,
        train_schema: GraphSchema,
        predict_schema: GraphSchema,
        domain_path: Path,
        output_filename: Path,
    ) -> GraphRunner:
        """Trains and packages a model and returns the prediction graph runner.

        Args:
            train_schema: The train graph schema.
            predict_schema: The predict graph schema.
            domain_path: The path to the domain file.
            output_filename: The location to save the packaged model.

        Returns:
            A graph runner loaded with the predict schema.
        """
        logger.debug("Starting training.")

        pruned_training_schema = self._fingerprint_and_prune(train_schema)

        hooks = [
            TrainingHook(cache=self._cache, model_storage=self._model_storage)
        ]

        graph_runner = self._graph_runner_class.create(
            graph_schema=pruned_training_schema,
            model_storage=self._model_storage,
            execution_context=ExecutionContext(
                graph_schema=pruned_training_schema),
            hooks=hooks,
        )

        logger.debug(
            "Running the pruned train graph with real node execution.")

        graph_runner.run()

        domain = Domain.from_path(domain_path)
        model_metadata = self._model_storage.create_model_package(
            output_filename, train_schema, predict_schema, domain)

        return self._graph_runner_class.create(
            graph_schema=predict_schema,
            model_storage=self._model_storage,
            execution_context=ExecutionContext(
                graph_schema=predict_schema, model_id=model_metadata.model_id),
        )
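A hedged usage sketch of how this `train` method might be driven end to end; the `trainer` instance, the two schemas, and the file paths below are hypothetical placeholders, not part of the example above:

# Hypothetical call site for the `train` method above; `trainer`,
# the schemas, and both paths are placeholders for illustration.
predict_runner = trainer.train(
    train_schema=train_schema,
    predict_schema=predict_schema,
    domain_path=Path("domain.yml"),
    output_filename=Path("models/model.tar.gz"),
)
prediction_outputs = predict_runner.run()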
Example #3
def test_unused_node(default_model_storage: ModelStorage):
    graph_schema = GraphSchema({
        "provide":
        SchemaNode(
            needs={},
            uses=ProvideX,
            fn="provide",
            constructor_name="create",
            config={},
            is_target=True,
        ),
        # This node will not fail as it will be pruned because it is not a target
        # or a target's ancestor.
        "assert_false":
        SchemaNode(
            needs={"i": "input"},
            uses=AssertComponent,
            fn="run_assert",
            constructor_name="create",
            config={"value_to_assert": "some_value"},
        ),
    })
    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=ExecutionContext(graph_schema=graph_schema,
                                           model_id="1"),
    )
    results = runner.run(inputs={"input": "some_other_value"})
    assert results == {"provide": 1}
Example #4
def test_fn_exception(default_model_storage: ModelStorage):
    class BadFn(GraphComponent):
        @classmethod
        def create(
            cls,
            config: Dict[Text, Any],
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
        ) -> BadFn:
            return cls()

        def run(self) -> None:
            raise ValueError("Oh no!")

    node = GraphNode(
        node_name="bad_fn",
        component_class=BadFn,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "some_id"),
    )

    with pytest.raises(GraphComponentException):
        node()
Example #5
def test_writing_to_resource_during_training(
        default_model_storage: ModelStorage):
    node_name = "some_name"

    test_value_for_sub_directory = {"test": "test value sub dir"}
    test_value = {"test dir": "test value dir"}

    node = GraphNode(
        node_name=node_name,
        component_class=PersistableTestComponent,
        constructor_name="create",
        component_config={
            "test_value": test_value,
            "test_value_for_sub_directory": test_value_for_sub_directory,
        },
        fn_name="train",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

    _, resource = node()

    assert resource == Resource(node_name)

    with default_model_storage.read_from(resource) as directory:
        assert (rasa.shared.utils.io.read_json_file(directory /
                                                    "test.json") == test_value)
        assert (rasa.shared.utils.io.read_json_file(
            directory / "sub_dir" /
            "test.json") == test_value_for_sub_directory)
Example #6
    def fingerprint(
        self,
        train_schema: GraphSchema,
        importer: TrainingDataImporter,
        is_finetuning: bool = False,
    ) -> Dict[Text, Union[FingerprintStatus, Any]]:
        """Runs the graph using fingerprints to determine which nodes need to re-run.

        Nodes which have a matching fingerprint key in the cache can either be removed
        entirely from the graph, or replaced with a cached value if their output is
        needed by descendant nodes.

        Args:
            train_schema: The train graph schema that will be run in fingerprint mode.
            importer: The importer which provides the training data for the training.
            is_finetuning: `True` if we want to finetune the model.

        Returns:
            Mapping of node names to fingerprint results.
        """
        fingerprint_schema = self._create_fingerprint_schema(train_schema)

        fingerprint_graph_runner = self._graph_runner_class.create(
            graph_schema=fingerprint_schema,
            model_storage=self._model_storage,
            execution_context=ExecutionContext(graph_schema=train_schema,
                                               is_finetuning=is_finetuning),
        )

        logger.debug("Running the train graph in fingerprint mode.")
        return fingerprint_graph_runner.run(
            inputs={PLACEHOLDER_IMPORTER: importer})
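The returned mapping can be inspected to see which nodes were cache hits. A minimal sketch, assuming only the `is_hit` flag that `FingerprintStatus` exposes in the other examples here; `trainer`, `train_schema`, and `importer` are placeholders:

# Hypothetical inspection of the fingerprint results.
statuses = trainer.fingerprint(train_schema, importer)
for node_name, status in statuses.items():
    if isinstance(status, FingerprintStatus) and status.is_hit:
        print(f"{node_name}: cache hit, eligible for pruning or replacement")
    else:
        print(f"{node_name}: needs to re-run")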
Example #7
def test_loading_from_resource_eager(default_model_storage: ModelStorage):
    previous_resource = Resource("previous resource")
    test_value = {"test": "test value"}

    # Pretend resource persisted itself before
    with default_model_storage.write_to(previous_resource) as directory:
        rasa.shared.utils.io.dump_obj_as_json_to_file(directory / "test.json",
                                                      test_value)

    node_name = "some_name"
    node = GraphNode(
        node_name=node_name,
        component_class=PersistableTestComponent,
        constructor_name="load",
        component_config={},
        fn_name="run_inference",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        # The `GraphComponent` should load from this resource
        resource=previous_resource,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

    actual_node_name, value = node()

    assert actual_node_name == node_name
    assert value == test_value
Example #8
def test_target_override(eager: bool, default_model_storage: ModelStorage):
    graph_schema = GraphSchema(
        {
            "add": SchemaNode(
                needs={"i1": "first_input", "i2": "second_input"},
                uses=AddInputs,
                fn="add",
                constructor_name="create",
                config={},
                eager=eager,
            ),
            "subtract_2": SchemaNode(
                needs={"i": "add"},
                uses=SubtractByX,
                fn="subtract_x",
                constructor_name="create",
                config={"x": 3},
                eager=eager,
                is_target=True,
            ),
        }
    )

    execution_context = ExecutionContext(graph_schema=graph_schema, model_id="1")

    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=execution_context,
    )
    results = runner.run(inputs={"first_input": 3, "second_input": 4}, targets=["add"])
    assert results == {"add": 7}
Example #9
    def _fingerprint_and_prune(self, train_schema: GraphSchema) -> GraphSchema:
        """Runs the graph using fingerprints to determine which nodes need to re-run.

        Nodes which have a matching fingerprint key in the cache can either be removed
        entirely from the graph, or replaced with a cached value if their output is
        needed by descendant nodes.

        Args:
            train_schema: The train graph schema that will be run in fingerprint mode.

        Returns:
            A new, potentially smaller and/or cached, graph schema.
        """
        fingerprint_schema = self._create_fingerprint_schema(train_schema)

        fingerprint_graph_runner = self._graph_runner_class.create(
            graph_schema=fingerprint_schema,
            model_storage=self._model_storage,
            execution_context=ExecutionContext(
                graph_schema=fingerprint_schema),
        )

        logger.debug("Running the train graph in fingerprint mode.")
        fingerprint_run_outputs = fingerprint_graph_runner.run()

        pruned_training_schema = self._prune_schema(train_schema,
                                                    fingerprint_run_outputs)
        return pruned_training_schema
Example #10
def test_loop(default_model_storage: ModelStorage):
    graph_schema = GraphSchema({
        "subtract_a":
        SchemaNode(
            needs={"i": "subtract_b"},
            uses=SubtractByX,
            fn="subtract_x",
            constructor_name="create",
            config={},
            is_target=False,
        ),
        "subtract_b":
        SchemaNode(
            needs={"i": "subtract_a"},
            uses=SubtractByX,
            fn="subtract_x",
            constructor_name="create",
            config={},
            is_target=True,
        ),
    })
    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=ExecutionContext(graph_schema=graph_schema,
                                           model_id="1"),
    )
    with pytest.raises(GraphRunError):
        runner.run()
Example #11
def test_unused_node(default_model_storage: ModelStorage):
    graph_schema = GraphSchema({
        "provide":
        SchemaNode(
            needs={},
            uses=ProvideX,
            fn="provide",
            constructor_name="create",
            config={},
            is_target=True,
        ),
        "provide_2":
        SchemaNode(  # Not a target, so its output is not returned
            needs={},
            uses=ProvideX,
            fn="provide",
            constructor_name="create",
            config={},
        ),
    })
    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=ExecutionContext(graph_schema=graph_schema,
                                           model_id="1"),
    )
    results = runner.run()
    assert results == {"provide": 1}
Example #12
def load_predict_graph_runner(
    storage_path: Path,
    model_archive_path: Path,
    model_storage_class: Type[ModelStorage],
    graph_runner_class: Type[GraphRunner],
) -> Tuple[ModelMetadata, GraphRunner]:
    """Loads a model from an archive and creates the prediction graph runner.

    Args:
        storage_path: Directory which contains the persisted graph components.
        model_archive_path: The path to the model archive.
        model_storage_class: The class to instantiate the model storage from.
        graph_runner_class: The class to instantiate the runner from.

    Returns:
        A tuple containing the model metadata and the prediction graph runner.
    """
    model_storage, model_metadata = model_storage_class.from_model_archive(
        storage_path=storage_path, model_archive_path=model_archive_path)
    runner = graph_runner_class.create(
        graph_schema=model_metadata.predict_schema,
        model_storage=model_storage,
        execution_context=ExecutionContext(
            graph_schema=model_metadata.predict_schema,
            model_id=model_metadata.model_id),
    )
    return model_metadata, runner
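A hedged usage sketch of this loader; the two paths are placeholders, while `LocalModelStorage` and `DaskGraphRunner` are the concrete classes these examples use elsewhere:

# Hypothetical call site; both paths are placeholders.
model_metadata, runner = load_predict_graph_runner(
    storage_path=Path("storage"),
    model_archive_path=Path("models/model.tar.gz"),
    model_storage_class=LocalModelStorage,
    graph_runner_class=DaskGraphRunner,
)
predictions = runner.run()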
Example #13
def test_training_hook_saves_to_cache(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    # We need an execution context so the hook can determine the class of the graph
    # component
    execution_context = ExecutionContext(
        GraphSchema({
            "hello":
            SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=CacheableComponent,
            )
        }),
        "1",
    )
    node = GraphNode(
        node_name="hello",
        component_class=CacheableComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={"suffix": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[
            TrainingHook(
                cache=temp_cache,
                model_storage=default_model_storage,
                pruned_schema=execution_context.graph_schema,
            )
        ],
    )

    node(("input_node", "Joe"))

    # This is the same key that the hook will generate
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=CacheableComponent,
        config={"prefix": "Hello "},
        inputs={"suffix": "Joe"},
    )

    output_fingerprint_key = temp_cache.get_cached_output_fingerprint(
        fingerprint_key)
    assert output_fingerprint_key

    cached_result = temp_cache.get_cached_result(
        output_fingerprint_key=output_fingerprint_key,
        model_storage=default_model_storage,
        node_name="hello",
    )
    assert isinstance(cached_result, CacheableText)
    assert cached_result.text == "Hello Joe"
Example #14
def test_empty_schema(default_model_storage: ModelStorage):
    empty_schema = GraphSchema({})
    runner = DaskGraphRunner(
        graph_schema=empty_schema,
        model_storage=default_model_storage,
        execution_context=ExecutionContext(graph_schema=empty_schema, model_id="1"),
    )
    results = runner.run()
    assert not results
Example #15
def test_fingerprint_component_hit(default_model_storage: ModelStorage,
                                   temp_cache: TrainingCache):

    cached_output = CacheableText("Cache me!!")
    output_fingerprint = uuid.uuid4().hex

    # We generate a fingerprint key that will match the one generated by the
    # `FingerprintComponent`.
    component_config = {"x": 1}
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config=component_config,
        inputs={
            "param_1": FingerprintableText("input_1"),
            "param_2": FingerprintableText("input_2"),
        },
    )
    # We cache the output using this fingerprint key.
    temp_cache.cache_output(
        fingerprint_key=fingerprint_key,
        output=cached_output,
        output_fingerprint=output_fingerprint,
        model_storage=default_model_storage,
    )

    # The node inputs and config match what we used to generate the fingerprint key.
    node = GraphNode(
        node_name="fingerprint_node",
        component_class=FingerprintComponent,
        constructor_name="create",
        component_config={
            "config_of_replaced_component": component_config,
            "cache": temp_cache,
            "graph_component_class": PrecomputedValueProvider,
        },
        fn_name="run",
        inputs={
            "param_1": "parent_node_1",
            "param_2": "parent_node_2"
        },
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    node_name, returned_output = node(
        ("parent_node_1", FingerprintableText("input_1")),
        ("parent_node_2",
         FingerprintStatus(is_hit=True, output_fingerprint="input_2")),
    )

    assert node_name == "fingerprint_node"
    assert returned_output.is_hit is True
    assert returned_output.output_fingerprint == output_fingerprint
    assert returned_output.output_fingerprint == returned_output.fingerprint()
Example #16
async def test_sparse_feature_sizes_decreased_incremental_training(
    iter1_path: Text,
    iter2_path: Text,
    should_raise_exception: bool,
    create_response_selector: Callable[
        [Dict[Text, Any]], ResponseSelectorGraphComponent
    ],
    load_response_selector: Callable[[Dict[Text, Any]], ResponseSelectorGraphComponent],
    default_execution_context: ExecutionContext,
    train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]],
    process_message: Callable[..., Message],
):
    pipeline = [
        {"component": WhitespaceTokenizerGraphComponent},
        {"component": LexicalSyntacticFeaturizerGraphComponent},
        {"component": RegexFeaturizerGraphComponent},
        {"component": CountVectorsFeaturizerGraphComponent},
        {
            "component": CountVectorsFeaturizerGraphComponent,
            "analyzer": "char_wb",
            "min_ngram": 1,
            "max_ngram": 4,
        },
    ]
    training_data, loaded_pipeline = train_and_preprocess(pipeline, iter1_path)

    response_selector = create_response_selector({EPOCHS: 1})
    response_selector.train(training_data=training_data)

    message = Message(data={TEXT: "Rasa is great!"})
    message = process_message(loaded_pipeline, message)

    message2 = copy.deepcopy(message)

    classified_message = response_selector.process([message])[0]

    default_execution_context.is_finetuning = True

    loaded_selector = load_response_selector({EPOCHS: 1})

    classified_message2 = loaded_selector.process([message2])[0]

    assert classified_message2.fingerprint() == classified_message.fingerprint()

    if should_raise_exception:
        with pytest.raises(Exception) as exec_info:
            training_data2, loaded_pipeline2 = train_and_preprocess(
                pipeline, iter2_path
            )
            loaded_selector.train(training_data=training_data2)
        assert "Sparse feature sizes have decreased" in str(exec_info.value)
    else:
        training_data2, loaded_pipeline2 = train_and_preprocess(pipeline, iter2_path)
        loaded_selector.train(training_data=training_data2)
        assert loaded_selector.model
Example #17
def test_execution_context(default_model_storage: ModelStorage):
    context = ExecutionContext(GraphSchema({}), "some_id")
    node = GraphNode(
        node_name="execution_context_aware",
        component_class=ExecutionContextAware,
        constructor_name="create",
        component_config={},
        fn_name="get_execution_context",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
    )

    context.model_id = "a_new_id"

    result = node()[1]
    assert result.model_id == "some_id"
    assert result.node_name == "execution_context_aware"
Example #18
def featurizer_sparse(tmpdir):
    """Generate a featurizer for tests."""
    node_storage = LocalModelStorage(pathlib.Path(tmpdir))
    node_resource = Resource("sparse_feat")
    context = ExecutionContext(node_storage, node_resource)
    return CountVectorsFeaturizer(
        config=CountVectorsFeaturizer.get_default_config(),
        resource=node_resource,
        model_storage=node_storage,
        execution_context=context,
    )
Example #19
async def test_process_gives_diagnostic_data(
    default_execution_context: ExecutionContext,
    create_response_selector: Callable[[Dict[Text, Any]], ResponseSelector],
    train_and_preprocess: Callable[..., Tuple[TrainingData,
                                              List[GraphComponent]]],
    process_message: Callable[..., Message],
):
    """Tests if processing a message returns attention weights as numpy array."""
    pipeline = [
        {
            "component": WhitespaceTokenizer
        },
        {
            "component": CountVectorsFeaturizer
        },
    ]
    config_params = {EPOCHS: 1}

    importer = RasaFileImporter(
        config_file="data/test_response_selector_bot/config.yml",
        domain_path="data/test_response_selector_bot/domain.yml",
        training_data_paths=[
            "data/test_response_selector_bot/data/rules.yml",
            "data/test_response_selector_bot/data/stories.yml",
            "data/test_response_selector_bot/data/nlu.yml",
        ],
    )
    training_data = importer.get_nlu_data()

    training_data, loaded_pipeline = train_and_preprocess(
        pipeline, training_data)

    default_execution_context.should_add_diagnostic_data = True

    response_selector = create_response_selector(config_params)
    response_selector.train(training_data=training_data)

    message = Message(data={TEXT: "hello"})
    message = process_message(loaded_pipeline, message)

    classified_message = response_selector.process([message])[0]
    diagnostic_data = classified_message.get(DIAGNOSTIC_DATA)

    assert isinstance(diagnostic_data, dict)
    for _, values in diagnostic_data.items():
        assert "text_transformed" in values
        assert isinstance(values.get("text_transformed"), np.ndarray)
        # The `attention_weights` key should exist, regardless of there
        # being a transformer
        assert "attention_weights" in values
        # By default, ResponseSelector has `number_of_transformer_layers = 0`
        # in which case the attention weights should be None.
        assert values.get("attention_weights") is None
Example #20
def create_component(
    component_class: Type[GraphComponent], config: Dict[Text, Any], idx: int
) -> GraphComponent:
    node_name = f"{component_class.__name__}_{idx}"
    execution_context = ExecutionContext(GraphSchema({}), node_name=node_name)
    resource = Resource(node_name)
    # `default_model_storage` is resolved from the enclosing fixture scope.
    return component_class.create(
        {**component_class.get_default_config(), **config},
        default_model_storage,
        resource,
        execution_context,
    )
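A possible call site for this helper, reusing the featurizer class and char n-gram config that appear in Example #16; treat it as an illustrative sketch rather than the original usage:

# Illustrative only: builds a featurizer component via the helper above.
featurizer = create_component(
    CountVectorsFeaturizer,
    {"analyzer": "char_wb", "min_ngram": 1, "max_ngram": 4},
    idx=0,
)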
Example #21
def test_execution_context(default_model_storage: ModelStorage):
    graph_schema = GraphSchema({
        "execution_context_aware":
        SchemaNode(
            needs={},
            uses=ExecutionContextAware,
            fn="get_execution_context",
            constructor_name="create",
            config={},
            is_target=True,
        )
    })
    context = ExecutionContext(graph_schema=graph_schema, model_id="some_id")
    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=context,
    )
    context.model_id = "a_new_id"
    result = runner.run()["execution_context_aware"]
    assert result.model_id == "some_id"
    assert result.node_name == "execution_context_aware"
Example #22
def test_training_hook_does_not_cache_cached_component(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    # We need an execution context so the hook can determine the class of the graph
    # component
    execution_context = ExecutionContext(
        GraphSchema({
            "hello":
            SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=PrecomputedValueProvider,
            )
        }),
        "1",
    )
    node = GraphNode(
        node_name="hello",
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        component_config={"output": CacheableText("hi")},
        fn_name="get_value",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[
            TrainingHook(
                cache=temp_cache,
                model_storage=default_model_storage,
                pruned_schema=execution_context.graph_schema,
            )
        ],
    )

    node(("input_node", "Joe"))

    # This is the same key that the hook will generate
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config={"output": CacheableText("hi")},
        inputs={},
    )

    # The hook should not cache the output of a PrecomputedValueProvider
    assert not temp_cache.get_cached_output_fingerprint(fingerprint_key)
Example #23
def test_predictions_added(training_data, tmpdir, featurizer_sparse):
    """Checks that intent predictions are added to the processed messages."""
    # Set up classifier
    node_storage = LocalModelStorage(pathlib.Path(tmpdir))
    node_resource = Resource("classifier")
    context = ExecutionContext(node_storage, node_resource)
    classifier = LogisticRegressionClassifier(
        config=LogisticRegressionClassifier.get_default_config(),
        name=context.node_name,
        resource=node_resource,
        model_storage=node_storage,
    )

    # First we add tokens. (`tokeniser` is assumed to be a fixture or helper
    # defined in the original test module.)
    tokeniser.process(training_data.training_examples)

    # Next we add features.
    featurizer_sparse.train(training_data)
    featurizer_sparse.process(training_data.training_examples)

    # Train the classifier.
    classifier.train(training_data)

    # Make predictions.
    classifier.process(training_data.training_examples)

    # Check that the messages have been processed correctly
    for msg in training_data.training_examples:
        _, conf = msg.get("intent")["name"], msg.get("intent")["confidence"]
        # Confidence should be between 0 and 1.
        assert 0 < conf < 1
        ranking = msg.get("intent_ranking")
        # `is_sorted` is assumed to be a helper defined in the original test module.
        assert is_sorted(ranking)
        assert {i["name"] for i in ranking} == {"greet", "goodbye"}
        # Confirm the sum of confidences is 1.0
        assert np.isclose(np.sum([i["confidence"] for i in ranking]), 1.0)

    classifier.persist()

    loaded_classifier = LogisticRegressionClassifier.load(
        {}, node_storage, node_resource, context
    )

    predicted = copy.copy(training_data)
    actual = copy.copy(training_data)
    loaded_messages = loaded_classifier.process(predicted.training_examples)
    trained_messages = classifier.process(actual.training_examples)
    for m1, m2 in zip(loaded_messages, trained_messages):
        assert m1.get("intent") == m2.get("intent")
Example #24
def trained_ted(
    tmp_path_factory: TempPathFactory, moodbot_domain_path: Path,
) -> TEDPolicyGraphComponent:
    training_files = "data/test_moodbot/data/stories.yml"
    domain = Domain.load(moodbot_domain_path)
    trackers = training.load_data(str(training_files), domain)
    policy = TEDPolicyGraphComponent.create(
        {**TEDPolicyGraphComponent.get_default_config(), EPOCHS: 1},
        LocalModelStorage.create(tmp_path_factory.mktemp("storage")),
        Resource("ted"),
        ExecutionContext(GraphSchema({})),
    )
    policy.train(trackers, domain)

    return policy
Example #25
def test_config_with_nested_dict_override(default_model_storage: ModelStorage):
    class ComponentWithNestedDictConfig(GraphComponent):
        @staticmethod
        def get_default_config() -> Dict[Text, Any]:
            return {"nested-dict": {"key1": "value1", "key2": "value2"}}

        @classmethod
        def create(
            cls,
            config: Dict,
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
            **kwargs: Any,
        ) -> ComponentWithNestedDictConfig:
            return cls()

        def run(self) -> None:
            return None

    node = GraphNode(
        node_name="nested_dict_config",
        component_class=ComponentWithNestedDictConfig,
        constructor_name="create",
        component_config={"nested-dict": {
            "key2": "override-value2"
        }},
        fn_name="run",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

    expected_config = {
        "nested-dict": {
            "key1": "value1",
            "key2": "override-value2"
        }
    }

    for key, value in expected_config.items():
        assert key in node._component_config
        if isinstance(value, dict):
            for nested_key, nested_value in expected_config[key].items():
                assert nested_key in node._component_config[key]
                assert node._component_config[key][nested_key] == nested_value
Example #26
def test_can_use_alternate_constructor(default_model_storage: ModelStorage):
    node = GraphNode(
        node_name="provide",
        component_class=ProvideX,
        constructor_name="create_with_2",
        component_config={},
        fn_name="provide",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    result = node()
    assert result == ("provide", 2)
Example #27
def test_component_config(x: Optional[int], output: int,
                          default_model_storage: ModelStorage):
    node = GraphNode(
        node_name="subtract",
        component_class=SubtractByX,
        constructor_name="create",
        component_config={"x": x} if x else {},
        fn_name="subtract_x",
        inputs={"i": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    result = node(("input_node", 5))

    assert result == ("subtract", output)
Example #28
def test_eager_and_not_eager(eager: bool, default_model_storage: ModelStorage):
    run_mock = Mock()
    create_mock = Mock()

    class SpyComponent(GraphComponent):
        @classmethod
        def create(
            cls,
            config: Dict,
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
        ) -> SpyComponent:
            create_mock()
            return cls()

        def run(self):
            return run_mock()

    node = GraphNode(
        node_name="spy_node",
        component_class=SpyComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={},
        eager=eager,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    if eager:
        assert create_mock.called
    else:
        assert not create_mock.called

    assert not run_mock.called

    node()

    assert create_mock.call_count == 1
    assert run_mock.called
Example #29
def test_no_target(default_model_storage: ModelStorage):
    graph_schema = GraphSchema({
        "provide":
        SchemaNode(
            needs={},
            uses=ProvideX,
            fn="provide",
            constructor_name="create",
            config={},
        )
    })
    runner = DaskGraphRunner(
        graph_schema=graph_schema,
        model_storage=default_model_storage,
        execution_context=ExecutionContext(graph_schema=graph_schema,
                                           model_id="1"),
    )
    results = runner.run()
    assert not results
Example #30
def test_calling_component(default_model_storage: ModelStorage):
    node = GraphNode(
        node_name="add_node",
        component_class=AddInputs,
        constructor_name="create",
        component_config={},
        fn_name="add",
        inputs={
            "i1": "input_node1",
            "i2": "input_node2"
        },
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    result = node(("input_node1", 3), ("input_node2", 4))

    assert result == ("add_node", 7)
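For reference, a minimal sketch of what a test component like `AddInputs` could look like, inferred from the `GraphComponent` constructor signature in Examples #4, #25, and #28; this is an assumption, not the actual test helper:

class AddInputs(GraphComponent):
    """Assumed shape of the test component; not the original source."""

    @classmethod
    def create(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> AddInputs:
        return cls()

    def add(self, i1: int, i2: int) -> int:
        # The node's `inputs` mapping binds parent outputs to these
        # keyword arguments (see `inputs={"i1": ..., "i2": ...}` above).
        return i1 + i2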