Ejemplo n.º 1
0
def test_fn_exception(default_model_storage: ModelStorage):
    """An error raised by a component's run function is reported as a
    `GraphComponentException` when the node is called."""

    class BadFn(GraphComponent):
        @classmethod
        def create(
            cls,
            config: Dict[Text, Any],
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
        ) -> BadFn:
            return cls()

        def run(self) -> None:
            raise ValueError("Oh no!")

    context = ExecutionContext(GraphSchema({}), "some_id")
    failing_node = GraphNode(
        node_name="bad_fn",
        component_class=BadFn,
        constructor_name="create",
        fn_name="run",
        component_config={},
        inputs={},
        eager=True,
        resource=None,
        model_storage=default_model_storage,
        execution_context=context,
    )

    with pytest.raises(GraphComponentException):
        failing_node()
Ejemplo n.º 2
0
def test_loading_from_resource_eager(default_model_storage: ModelStorage):
    """An eager node built via `load` returns the data stored in an existing
    resource, tagged with the node's name."""
    previous_resource = Resource("previous resource")
    expected_value = {"test": "test value"}

    # Simulate the resource having persisted itself in an earlier run.
    with default_model_storage.write_to(previous_resource) as resource_dir:
        json_path = resource_dir / "test.json"
        rasa.shared.utils.io.dump_obj_as_json_to_file(json_path, expected_value)

    node_name = "some_name"
    context = ExecutionContext(GraphSchema({}), "123")
    node = GraphNode(
        node_name=node_name,
        component_class=PersistableTestComponent,
        constructor_name="load",
        fn_name="run_inference",
        component_config={},
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        # The component is expected to load its state from this resource.
        resource=previous_resource,
        execution_context=context,
    )

    returned_name, loaded_value = node()

    assert returned_name == node_name
    assert loaded_value == expected_value
Ejemplo n.º 3
0
def test_writing_to_resource_during_training(
    default_model_storage: ModelStorage,
):
    """Training a non-eager node writes the configured values into a resource
    named after the node (top-level file and a `sub_dir` file)."""
    node_name = "some_name"

    dir_value = {"test dir": "test value dir"}
    sub_dir_value = {"test": "test value sub dir"}

    node = GraphNode(
        node_name=node_name,
        component_class=PersistableTestComponent,
        constructor_name="create",
        component_config={
            "test_value": dir_value,
            "test_value_for_sub_directory": sub_dir_value,
        },
        fn_name="train",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

    _, resource = node()

    assert resource == Resource(node_name)

    read_json = rasa.shared.utils.io.read_json_file
    with default_model_storage.read_from(resource) as directory:
        assert read_json(directory / "test.json") == dir_value
        assert read_json(directory / "sub_dir" / "test.json") == sub_dir_value
Ejemplo n.º 4
0
def test_training_hook_saves_to_cache(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """`TrainingHook` stores a `CacheableComponent`'s output in the cache under
    the fingerprint key derived from its class, config and inputs."""
    # The hook looks up the graph component's class through the execution
    # context, so the schema must describe the "hello" node.
    hello_schema_node = SchemaNode(
        needs={},
        constructor_name="create",
        fn="run",
        config={},
        uses=CacheableComponent,
    )
    execution_context = ExecutionContext(
        GraphSchema({"hello": hello_schema_node}), "1")
    training_hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=execution_context.graph_schema,
    )
    node = GraphNode(
        node_name="hello",
        component_class=CacheableComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={"suffix": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[training_hook],
    )

    node(("input_node", "Joe"))

    # Recompute the key the hook is expected to derive for this call.
    expected_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=CacheableComponent,
        config={"prefix": "Hello "},
        inputs={"suffix": "Joe"},
    )

    output_fingerprint_key = temp_cache.get_cached_output_fingerprint(
        expected_key)
    assert output_fingerprint_key

    cached_result = temp_cache.get_cached_result(
        output_fingerprint_key=output_fingerprint_key,
        model_storage=default_model_storage,
        node_name="hello",
    )
    assert isinstance(cached_result, CacheableText)
    assert cached_result.text == "Hello Joe"
Ejemplo n.º 5
0
def test_fingerprint_component_hit(default_model_storage: ModelStorage,
                                   temp_cache: TrainingCache):
    """`FingerprintComponent` reports a hit, carrying the cached output
    fingerprint, when the cache already holds a matching key."""
    cached_output = CacheableText("Cache me!!")
    output_fingerprint = uuid.uuid4().hex

    # Build the fingerprint key that should match the one `FingerprintComponent`
    # generates for the replaced component's class, config and inputs.
    component_config = {"x": 1}
    fingerprint_inputs = {
        "param_1": FingerprintableText("input_1"),
        "param_2": FingerprintableText("input_2"),
    }
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config=component_config,
        inputs=fingerprint_inputs,
    )
    # Seed the cache under that key.
    temp_cache.cache_output(
        fingerprint_key=fingerprint_key,
        output=cached_output,
        output_fingerprint=output_fingerprint,
        model_storage=default_model_storage,
    )

    # The node's config and inputs line up with what produced the key above.
    node = GraphNode(
        node_name="fingerprint_node",
        component_class=FingerprintComponent,
        constructor_name="create",
        component_config={
            "config_of_replaced_component": component_config,
            "cache": temp_cache,
            "graph_component_class": PrecomputedValueProvider,
        },
        fn_name="run",
        inputs={"param_1": "parent_node_1", "param_2": "parent_node_2"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    node_name, status = node(
        ("parent_node_1", FingerprintableText("input_1")),
        ("parent_node_2",
         FingerprintStatus(is_hit=True, output_fingerprint="input_2")),
    )

    assert node_name == "fingerprint_node"
    assert status.is_hit is True
    assert status.output_fingerprint == output_fingerprint
    assert status.output_fingerprint == status.fingerprint()
Ejemplo n.º 6
0
 def _instantiate_nodes(
     graph_schema: GraphSchema,
     model_storage: ModelStorage,
     execution_context: ExecutionContext,
     hooks: Optional[List[GraphNodeHook]] = None,
 ) -> Dict[Text, GraphNode]:
     """Build one `GraphNode` per entry in `graph_schema.nodes`.

     Each node is created with `GraphNode.from_schema_node`, sharing the same
     model storage, execution context and hooks.

     Returns:
         A dict mapping each node name to its `GraphNode`.
     """
     nodes = {}
     for name, schema_node in graph_schema.nodes.items():
         nodes[name] = GraphNode.from_schema_node(
             name, schema_node, model_storage, execution_context, hooks)
     return nodes
Ejemplo n.º 7
0
def test_training_hook_does_not_cache_cached_component(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """`TrainingHook` does not cache the output of a `PrecomputedValueProvider`."""
    # We need an execution context so the hook can determine the class of the graph
    # component
    execution_context = ExecutionContext(
        GraphSchema({
            "hello":
            SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=PrecomputedValueProvider,
            )
        }),
        "1",
    )
    node = GraphNode(
        node_name="hello",
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        component_config={"output": CacheableText("hi")},
        fn_name="get_value",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[
            TrainingHook(
                cache=temp_cache,
                model_storage=default_model_storage,
                pruned_schema=execution_context.graph_schema,
            )
        ],
    )

    # The node declares no inputs (`inputs={}`), so call it without any. A
    # stray input here would make it unclear whether the key recomputed below
    # (with `inputs={}`) really corresponds to this call, and the negative
    # assertion could then pass vacuously.
    node()

    # This is the same key that the hook will generate
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config={"output": CacheableText("hi")},
        inputs={},
    )

    # The hook should not cache the output of a PrecomputedValueProvider
    assert not temp_cache.get_cached_output_fingerprint(fingerprint_key)
Ejemplo n.º 8
0
def test_config_with_nested_dict_override(default_model_storage: ModelStorage):
    """Overriding one key of a nested default-config dict keeps the other
    nested defaults intact."""

    class ComponentWithNestedDictConfig(GraphComponent):
        @staticmethod
        def get_default_config() -> Dict[Text, Any]:
            return {"nested-dict": {"key1": "value1", "key2": "value2"}}

        @classmethod
        def create(
            cls,
            config: Dict,
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
            **kwargs: Any,
        ) -> ComponentWithNestedDictConfig:
            return cls()

        def run(self) -> None:
            return None

    node = GraphNode(
        node_name="nested_dict_config",
        component_class=ComponentWithNestedDictConfig,
        constructor_name="create",
        component_config={"nested-dict": {
            "key2": "override-value2"
        }},
        fn_name="run",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

    # `key1` keeps its default value while `key2` takes the override.
    expected_config = {
        "nested-dict": {
            "key1": "value1",
            "key2": "override-value2"
        }
    }

    # Check that the expected entries are present (a subset check); any other
    # keys in the node's config are not asserted on.
    for key, expected_value in expected_config.items():
        assert key in node._component_config
        actual_value = node._component_config[key]
        if isinstance(expected_value, dict):
            for nested_key, nested_value in expected_value.items():
                assert nested_key in actual_value
                assert actual_value[nested_key] == nested_value
        else:
            # Non-dict values were previously skipped without any comparison.
            assert actual_value == expected_value
Ejemplo n.º 9
0
def test_can_use_alternate_constructor(default_model_storage: ModelStorage):
    """A node can build its component via a constructor other than `create`."""
    context = ExecutionContext(GraphSchema({}), "1")
    node = GraphNode(
        node_name="provide",
        component_class=ProvideX,
        constructor_name="create_with_2",
        fn_name="provide",
        component_config={},
        inputs={},
        eager=False,
        resource=None,
        model_storage=default_model_storage,
        execution_context=context,
    )

    assert node() == ("provide", 2)
Ejemplo n.º 10
0
def test_exception_handling_for_on_before_hook(
    on_before: Callable,
    on_after: Callable,
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
):
    """Errors raised from a node hook surface as a `GraphComponentException`.

    `on_before` is invoked from `on_before_node` and `on_after` from
    `on_after_node`. Both callbacks are supplied by test parametrization that
    is not visible here; presumably each case makes one of them raise.
    """
    schema_node = SchemaNode(
        needs={},
        uses=ProvideX,
        fn="provide",
        constructor_name="create",
        config={},
    )

    class MyHook(GraphNodeHook):
        def on_before_node(
            self,
            node_name: Text,
            execution_context: ExecutionContext,
            config: Dict[Text, Any],
            received_inputs: Dict[Text, Any],
        ) -> Dict:
            # Each hook method calls the callback that carries its own name.
            on_before()
            return {}

        def on_after_node(
            self,
            node_name: Text,
            execution_context: ExecutionContext,
            config: Dict[Text, Any],
            output: Any,
            input_hook_data: Dict,
        ) -> None:
            on_after()

    node = GraphNode.from_schema_node(
        "some_node",
        schema_node,
        default_model_storage,
        default_execution_context,
        hooks=[MyHook()],
    )

    with pytest.raises(GraphComponentException):
        node()
Ejemplo n.º 11
0
def test_component_config(x: Optional[int], output: int,
                          default_model_storage: ModelStorage):
    """A node forwards `x` to `SubtractByX` and returns the component's result
    for input 5.

    `x=None` means "leave x unconfigured", so the component's default applies.
    """
    # Compare against None explicitly: with plain truthiness an explicit `x=0`
    # would be dropped and treated as "not configured".
    component_config = {"x": x} if x is not None else {}

    node = GraphNode(
        node_name="subtract",
        component_class=SubtractByX,
        constructor_name="create",
        component_config=component_config,
        fn_name="subtract_x",
        inputs={"i": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    result = node(("input_node", 5))

    assert result == ("subtract", output)
Ejemplo n.º 12
0
def test_eager_and_not_eager(eager: bool, default_model_storage: ModelStorage):
    """An eager node creates its component at construction time; a lazy one
    only does so when called. Either way the component is created once."""
    run_spy = Mock()
    create_spy = Mock()

    class SpyComponent(GraphComponent):
        @classmethod
        def create(
            cls,
            config: Dict,
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
        ) -> SpyComponent:
            create_spy()
            return cls()

        def run(self):
            return run_spy()

    context = ExecutionContext(GraphSchema({}), "1")
    node = GraphNode(
        node_name="spy_node",
        component_class=SpyComponent,
        constructor_name="create",
        fn_name="run",
        component_config={},
        inputs={},
        eager=eager,
        resource=None,
        model_storage=default_model_storage,
        execution_context=context,
    )

    # Only an eager node should have built its component already.
    assert create_spy.called == bool(eager)
    assert not run_spy.called

    node()

    assert create_spy.call_count == 1
    assert run_spy.called
Ejemplo n.º 13
0
def test_calling_component(default_model_storage: ModelStorage):
    """Calling a node routes parent outputs to the component's named inputs
    and returns the result tagged with the node name."""
    input_mapping = {"i1": "input_node1", "i2": "input_node2"}
    context = ExecutionContext(GraphSchema({}), "1")
    node = GraphNode(
        node_name="add_node",
        component_class=AddInputs,
        constructor_name="create",
        fn_name="add",
        component_config={},
        inputs=input_mapping,
        eager=False,
        resource=None,
        model_storage=default_model_storage,
        execution_context=context,
    )

    assert node(("input_node1", 3), ("input_node2", 4)) == ("add_node", 7)
Ejemplo n.º 14
0
def test_execution_context(default_model_storage: ModelStorage):
    """The component sees the execution context as it was when the node was
    built, with the node's name filled in."""
    original_context = ExecutionContext(GraphSchema({}), "some_id")
    node = GraphNode(
        node_name="execution_context_aware",
        component_class=ExecutionContextAware,
        constructor_name="create",
        fn_name="get_execution_context",
        component_config={},
        inputs={},
        eager=False,
        resource=None,
        model_storage=default_model_storage,
        execution_context=original_context,
    )

    # Changing the caller's context afterwards must not leak into the node.
    original_context.model_id = "a_new_id"

    seen_context = node()[1]
    assert seen_context.model_id == "some_id"
    assert seen_context.node_name == "execution_context_aware"
Ejemplo n.º 15
0
def test_cached_component_returns_value_from_cache(
    default_model_storage: ModelStorage,
):
    """A `PrecomputedValueProvider` node returns the output it was configured
    with."""
    precomputed = CacheableText("Cache me!!")

    context = ExecutionContext(GraphSchema({}), "1")
    node = GraphNode(
        node_name="cached",
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        fn_name="get_value",
        component_config={"output": precomputed},
        inputs={},
        eager=False,
        resource=None,
        model_storage=default_model_storage,
        execution_context=context,
    )

    returned_name, returned_output = node()

    assert returned_name == "cached"
    assert returned_output.text == "Cache me!!"
Ejemplo n.º 16
0
def test_fingerprint_component_miss(default_model_storage: ModelStorage,
                                    temp_cache: TrainingCache):
    """`FingerprintComponent` reports a miss when nothing has been cached."""
    replaced_config = {"x": 1}

    node = GraphNode(
        node_name="fingerprint_node",
        component_class=FingerprintComponent,
        constructor_name="create",
        component_config={
            "config_of_replaced_component": replaced_config,
            "cache": temp_cache,
            "graph_component_class": PrecomputedValueProvider,
        },
        fn_name="run",
        inputs={"param_1": "parent_node_1", "param_2": "parent_node_2"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    node_name, status = node(
        ("parent_node_1", FingerprintableText("input_1")),
        ("parent_node_2",
         FingerprintStatus(is_hit=True, output_fingerprint="input_2")),
    )

    # Nothing was added to the cache, so this cannot be a hit.
    assert node_name == "fingerprint_node"
    assert status.is_hit is False
    assert status.output_fingerprint is None
    # Without an output fingerprint, `fingerprint()` matches neither the
    # (missing) output fingerprint nor the value from a second call.
    assert status.fingerprint() != status.output_fingerprint
    assert status.fingerprint() != status.fingerprint()