Beispiel #1
0
def test_training_hook_saves_to_cache(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """Check that running a node with a `TrainingHook` fills the cache."""
    node_name = "hello"

    # The hook determines the class of the graph component from the execution
    # context, so a schema describing the node has to be provided.
    schema_node = SchemaNode(
        needs={},
        constructor_name="create",
        fn="run",
        config={},
        uses=CacheableComponent,
    )
    execution_context = ExecutionContext(
        GraphSchema({node_name: schema_node}),
        "1",
    )

    training_hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=execution_context.graph_schema,
    )
    node = GraphNode(
        node_name=node_name,
        component_class=CacheableComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={"suffix": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[training_hook],
    )

    # Run the node with the output of the (pretended) upstream node
    node(("input_node", "Joe"))

    # Rebuild the key which the hook is expected to have generated
    expected_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=CacheableComponent,
        config={"prefix": "Hello "},
        inputs={"suffix": "Joe"},
    )

    # The hook must have stored an output fingerprint under that key ...
    stored_fingerprint = temp_cache.get_cached_output_fingerprint(expected_key)
    assert stored_fingerprint

    # ... and the output itself has to be retrievable from the cache
    restored = temp_cache.get_cached_result(
        output_fingerprint_key=stored_fingerprint,
        model_storage=default_model_storage,
        node_name=node_name,
    )
    assert isinstance(restored, CacheableText)
    assert restored.text == "Hello Joe"
Beispiel #2
0
def test_get_cached_result_with_miss(temp_cache: TrainingCache,
                                     default_model_storage: ModelStorage):
    """Check that lookups with unknown keys return `None`."""
    # Put an unrelated entry into the cache so that it is not empty
    unrelated_output = TestCacheableOutput({"something to cache": "dasdaasda"})
    temp_cache.cache_output(
        uuid.uuid4().hex,
        unrelated_output,
        uuid.uuid4().hex,
        default_model_storage,
    )

    # A freshly generated output fingerprint can't match the stored entry
    missing_result = temp_cache.get_cached_result(
        uuid.uuid4().hex,
        "some node",
        default_model_storage,
    )
    assert missing_result is None

    # The same holds for a freshly generated fingerprint key
    missing_fingerprint = temp_cache.get_cached_output_fingerprint(
        uuid.uuid4().hex)
    assert missing_fingerprint is None
Beispiel #3
0
def test_cache_output(temp_cache: TrainingCache,
                      default_model_storage: ModelStorage):
    """Check that a cached output and its fingerprint can be read back."""
    key = uuid.uuid4().hex
    expected_output = TestCacheableOutput({"something to cache": "dasdaasda"})
    expected_fingerprint = uuid.uuid4().hex

    temp_cache.cache_output(
        key,
        expected_output,
        expected_fingerprint,
        default_model_storage,
    )

    # The fingerprint key maps to the output fingerprint ...
    stored_fingerprint = temp_cache.get_cached_output_fingerprint(key)
    assert stored_fingerprint == expected_fingerprint

    # ... and the output fingerprint maps to the cached output
    restored_output = temp_cache.get_cached_result(
        expected_fingerprint,
        "some_node",
        default_model_storage,
    )
    assert restored_output == expected_output
Beispiel #4
0
def test_training_hook_does_not_cache_cached_component(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """Check that the hook skips outputs of a `PrecomputedValueProvider`."""
    node_name = "hello"

    # The hook determines the class of the graph component from the execution
    # context, so a schema describing the node has to be provided.
    schema_node = SchemaNode(
        needs={},
        constructor_name="create",
        fn="run",
        config={},
        uses=PrecomputedValueProvider,
    )
    execution_context = ExecutionContext(
        GraphSchema({node_name: schema_node}),
        "1",
    )

    training_hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=execution_context.graph_schema,
    )
    node = GraphNode(
        node_name=node_name,
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        component_config={"output": CacheableText("hi")},
        fn_name="get_value",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[training_hook],
    )

    node(("input_node", "Joe"))

    # Rebuild the key which the hook would generate for this node
    expected_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config={"output": CacheableText("hi")},
        inputs={},
    )

    # Nothing must have been cached for a `PrecomputedValueProvider`
    stored_fingerprint = temp_cache.get_cached_output_fingerprint(expected_key)
    assert not stored_fingerprint
Beispiel #5
0
def test_caching_something_which_is_not_cacheable(
        temp_cache: TrainingCache, default_model_storage: ModelStorage):
    """Check that caching `None` keeps the fingerprint but no result."""
    key = uuid.uuid4().hex
    expected_fingerprint = uuid.uuid4().hex

    # `None` is passed as the output to cache
    temp_cache.cache_output(
        key,
        None,
        expected_fingerprint,
        default_model_storage,
    )

    # The output fingerprint is still recorded
    stored_fingerprint = temp_cache.get_cached_output_fingerprint(key)
    assert stored_fingerprint == expected_fingerprint

    # There is no result on disk which could be restored, though
    restored = temp_cache.get_cached_result(
        expected_fingerprint,
        "some_node",
        default_model_storage,
    )
    assert restored is None
Beispiel #6
0
def test_cache_again(
    temp_cache: TrainingCache,
    default_model_storage: ModelStorage,
    initial_output_fingerprint: Text,
    second_output_fingerprint: Text,
):
    """Check that caching a fingerprint key twice keeps the latest fingerprint."""
    key = uuid.uuid4().hex

    # Cache the same fingerprint key twice in a row.
    # Only the non-`Cacheable` case is covered: a `Cacheable` result can't be
    # cached twice as the component would have been replaced with a
    # `PrecomputedValueProvider` otherwise.
    for output_fingerprint in (initial_output_fingerprint,
                               second_output_fingerprint):
        temp_cache.cache_output(
            key,
            None,
            output_fingerprint,
            default_model_storage,
        )

    stored_fingerprint = temp_cache.get_cached_output_fingerprint(key)
    assert stored_fingerprint == second_output_fingerprint
Beispiel #7
0
def test_caching_cacheable_fails(
    tmp_path: Path,
    caplog: LogCaptureFixture,
    temp_cache: TrainingCache,
    default_model_storage: ModelStorage,
):
    """Check that a failing cache attempt is logged and stores no result."""
    key = uuid.uuid4().hex
    expected_fingerprint = uuid.uuid4().hex

    # `tmp_path` is not a dict and will hence fail to be cached
    # noinspection PyTypeChecker
    broken_output = TestCacheableOutput(tmp_path)

    with caplog.at_level(logging.ERROR):
        temp_cache.cache_output(
            key,
            broken_output,
            expected_fingerprint,
            default_model_storage,
        )

    # Exactly one record was captured for the failed attempt
    assert len(caplog.records) == 1

    # The output fingerprint is kept regardless of the failure
    stored_fingerprint = temp_cache.get_cached_output_fingerprint(key)
    assert stored_fingerprint == expected_fingerprint

    # No result can be restored as nothing was persisted
    restored = temp_cache.get_cached_result(
        expected_fingerprint,
        "some_node",
        default_model_storage,
    )
    assert restored is None