Example #1
0
def test_fingerprint_stays_same():
    """Identical class, config and inputs must produce identical keys."""
    generated_keys = [
        fingerprinting.calculate_fingerprint_key(
            TEDPolicy,
            TEDPolicy.get_default_config(),
            {"input": FingerprintableText("Hi")},
        )
        for _ in range(2)
    ]

    assert generated_keys[0] == generated_keys[1]
Example #2
0
def test_fingerprint_changes_due_to_inputs():
    """Different component class, config and inputs must yield different keys."""
    first_key = fingerprinting.calculate_fingerprint_key(
        TEDPolicy,
        {},
        {"input": FingerprintableText("Hi")},
    )
    second_key = fingerprinting.calculate_fingerprint_key(
        ResponseSelector,
        TEDPolicy.get_default_config(),
        {"input": FingerprintableText("bye")},
    )

    assert first_key != second_key
Example #3
0
def test_fingerprint_changes_due_to_changed_source(monkeypatch: MonkeyPatch):
    """Patching `inspect.getsource` simulates a changed component implementation."""
    original_key = fingerprinting.calculate_fingerprint_key(
        TEDPolicy, {}, {"input": FingerprintableText("Hi")}
    )

    # Pretend the component's source code changed between the two calls.
    source_mock = Mock(return_value="other implementation")
    monkeypatch.setattr(inspect, inspect.getsource.__name__, source_mock)

    patched_key = fingerprinting.calculate_fingerprint_key(
        TEDPolicy, {}, {"input": FingerprintableText("Hi")}
    )

    assert original_key != patched_key

    source_mock.assert_called_once_with(TEDPolicy)
Example #4
0
    def run(self, **kwargs: Any) -> FingerprintStatus:
        """Calculates the fingerprint key to determine if cached output can be used.

        If the fingerprint key matches an entry in the cache it means that there has
        been a previous node execution which matches the same component class, component
        config and input values. This means that we can potentially prune this node
        from the schema, or replace it with a cached value before the next graph run.

        Args:
            **kwargs: Inputs from all parent nodes.

        Returns:
            A `FingerprintStatus` determining if the run was a hit, and if it was a hit
            also the output fingerprint from the cache.
        """
        # Explicit overrides take precedence over the component's defaults.
        merged_config = {
            **self._class_of_replaced_component.get_default_config(),
            **self._config_of_replaced_component,
        }
        fingerprint_key = fingerprinting.calculate_fingerprint_key(
            graph_component_class=self._class_of_replaced_component,
            config=merged_config,
            inputs=kwargs,
        )

        cached_fingerprint = self._cache.get_cached_output_fingerprint(
            fingerprint_key
        )

        # A cache hit is signalled by the presence of a stored output fingerprint.
        return FingerprintStatus(
            is_hit=cached_fingerprint is not None,
            output_fingerprint=cached_fingerprint,
        )
Example #5
0
def test_training_hook_saves_to_cache(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """Running a cacheable node with a `TrainingHook` persists its output."""
    # The hook resolves the graph component class via the execution context's
    # schema, so the schema must contain the node under test.
    schema = GraphSchema(
        {
            "hello": SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=CacheableComponent,
            )
        }
    )
    execution_context = ExecutionContext(schema, "1")

    training_hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=execution_context.graph_schema,
    )
    node = GraphNode(
        node_name="hello",
        component_class=CacheableComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={"suffix": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[training_hook],
    )

    node(("input_node", "Joe"))

    # Recompute the exact key the hook generated during the run above.
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=CacheableComponent,
        config={"prefix": "Hello "},
        inputs={"suffix": "Joe"},
    )

    output_fingerprint_key = temp_cache.get_cached_output_fingerprint(fingerprint_key)
    assert output_fingerprint_key

    cached_result = temp_cache.get_cached_result(
        output_fingerprint_key=output_fingerprint_key,
        model_storage=default_model_storage,
        node_name="hello",
    )
    assert isinstance(cached_result, CacheableText)
    assert cached_result.text == "Hello Joe"
Example #6
0
def test_fingerprint_component_hit(default_model_storage: ModelStorage,
                                   temp_cache: TrainingCache):
    """A `FingerprintComponent` reports a hit when the cache holds a matching entry."""
    cached_output = CacheableText("Cache me!!")
    output_fingerprint = uuid.uuid4().hex

    # Pre-seed the cache under the exact key the `FingerprintComponent`
    # will compute for the replaced component.
    component_config = {"x": 1}
    matching_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config=component_config,
        inputs={
            "param_1": FingerprintableText("input_1"),
            "param_2": FingerprintableText("input_2"),
        },
    )
    temp_cache.cache_output(
        fingerprint_key=matching_key,
        output=cached_output,
        output_fingerprint=output_fingerprint,
        model_storage=default_model_storage,
    )

    # Node config and inputs mirror what was used to build the key above.
    node = GraphNode(
        node_name="fingerprint_node",
        component_class=FingerprintComponent,
        constructor_name="create",
        component_config={
            "config_of_replaced_component": component_config,
            "cache": temp_cache,
            "graph_component_class": PrecomputedValueProvider,
        },
        fn_name="run",
        inputs={"param_1": "parent_node_1", "param_2": "parent_node_2"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    node_name, returned_output = node(
        ("parent_node_1", FingerprintableText("input_1")),
        ("parent_node_2",
         FingerprintStatus(is_hit=True, output_fingerprint="input_2")),
    )

    assert node_name == "fingerprint_node"
    assert returned_output.is_hit is True
    assert returned_output.output_fingerprint == output_fingerprint
    assert returned_output.output_fingerprint == returned_output.fingerprint()
Example #7
0
def test_training_hook_does_not_cache_cached_component(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """A `TrainingHook` must skip caching for `PrecomputedValueProvider` nodes."""
    # The hook determines the component class from the execution context's schema.
    schema = GraphSchema(
        {
            "hello": SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=PrecomputedValueProvider,
            )
        }
    )
    execution_context = ExecutionContext(schema, "1")

    training_hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=execution_context.graph_schema,
    )
    node = GraphNode(
        node_name="hello",
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        component_config={"output": CacheableText("hi")},
        fn_name="get_value",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[training_hook],
    )

    node(("input_node", "Joe"))

    # Recompute the key the hook would have used, had it cached the output.
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config={"output": CacheableText("hi")},
        inputs={},
    )

    # The hook should not cache the output of a PrecomputedValueProvider
    assert not temp_cache.get_cached_output_fingerprint(fingerprint_key)
Example #8
0
    def on_before_node(
        self,
        node_name: Text,
        execution_context: ExecutionContext,
        config: Dict[Text, Any],
        received_inputs: Dict[Text, Any],
    ) -> Dict:
        """Calculates the run fingerprint for use in `on_after_node`."""
        component_class = self._get_graph_component_class(
            execution_context, node_name
        )

        # The key is threaded through to `on_after_node` via the returned dict.
        return {
            "fingerprint_key": fingerprinting.calculate_fingerprint_key(
                graph_component_class=component_class,
                config=config,
                inputs=received_inputs,
            )
        }