def test_training_hook_saves_to_cache(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """Running a cacheable node with a ``TrainingHook`` persists its output."""
    # The hook needs an execution context to resolve the component class of
    # the node it is attached to.
    schema = GraphSchema(
        {
            "hello": SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=CacheableComponent,
            )
        }
    )
    context = ExecutionContext(schema, "1")

    hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=context.graph_schema,
    )
    node = GraphNode(
        node_name="hello",
        component_class=CacheableComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={"suffix": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
        hooks=[hook],
    )

    node(("input_node", "Joe"))

    # Recompute the fingerprint key exactly the way the hook does so we can
    # look up what it stored.
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=CacheableComponent,
        config={"prefix": "Hello "},
        inputs={"suffix": "Joe"},
    )

    output_fingerprint_key = temp_cache.get_cached_output_fingerprint(fingerprint_key)
    assert output_fingerprint_key

    restored = temp_cache.get_cached_result(
        output_fingerprint_key=output_fingerprint_key,
        model_storage=default_model_storage,
        node_name="hello",
    )
    assert isinstance(restored, CacheableText)
    assert restored.text == "Hello Joe"
def test_get_cached_result_with_miss(
    temp_cache: TrainingCache, default_model_storage: ModelStorage
):
    """Lookups with unknown keys miss the cache and return ``None``."""
    # Populate the cache with an unrelated entry first.
    temp_cache.cache_output(
        uuid.uuid4().hex,
        TestCacheableOutput({"something to cache": "dasdaasda"}),
        uuid.uuid4().hex,
        default_model_storage,
    )

    # Fresh random keys can never match the entry cached above.
    miss = temp_cache.get_cached_result(
        uuid.uuid4().hex, "some node", default_model_storage
    )
    assert miss is None
    assert temp_cache.get_cached_output_fingerprint(uuid.uuid4().hex) is None
def test_cache_output(temp_cache: TrainingCache, default_model_storage: ModelStorage):
    """Round trip: a cached output can be looked up and restored again."""
    key = uuid.uuid4().hex
    value = TestCacheableOutput({"something to cache": "dasdaasda"})
    value_fingerprint = uuid.uuid4().hex

    temp_cache.cache_output(key, value, value_fingerprint, default_model_storage)

    assert temp_cache.get_cached_output_fingerprint(key) == value_fingerprint
    restored = temp_cache.get_cached_result(
        value_fingerprint, "some_node", default_model_storage
    )
    assert restored == value
def test_training_hook_does_not_cache_cached_component(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """A ``PrecomputedValueProvider``'s output must not be cached again."""
    # The hook needs an execution context to resolve the component class of
    # the node it is attached to.
    schema = GraphSchema(
        {
            "hello": SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=PrecomputedValueProvider,
            )
        }
    )
    context = ExecutionContext(schema, "1")

    hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=context.graph_schema,
    )
    node = GraphNode(
        node_name="hello",
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        component_config={"output": CacheableText("hi")},
        fn_name="get_value",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
        hooks=[hook],
    )

    node(("input_node", "Joe"))

    # Recompute the fingerprint key exactly the way the hook would.
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config={"output": CacheableText("hi")},
        inputs={},
    )

    # The hook should not cache the output of a PrecomputedValueProvider
    assert not temp_cache.get_cached_output_fingerprint(fingerprint_key)
def test_caching_something_which_is_not_cacheable(
    temp_cache: TrainingCache, default_model_storage: ModelStorage
):
    """Non-cacheable outputs record a fingerprint but no restorable result."""
    key = uuid.uuid4().hex
    value_fingerprint = uuid.uuid4().hex

    # ``None`` stands in for an output that cannot be persisted.
    temp_cache.cache_output(key, None, value_fingerprint, default_model_storage)

    # The fingerprint mapping is still recorded ...
    assert temp_cache.get_cached_output_fingerprint(key) == value_fingerprint

    # ... but nothing was stored to disk, so there is no result to restore.
    miss = temp_cache.get_cached_result(
        value_fingerprint, "some_node", default_model_storage
    )
    assert miss is None
def test_cache_again(
    temp_cache: TrainingCache,
    default_model_storage: ModelStorage,
    initial_output_fingerprint: Text,
    second_output_fingerprint: Text,
):
    """Re-caching the same fingerprint key replaces the stored fingerprint."""
    key = uuid.uuid4().hex
    temp_cache.cache_output(
        key, None, initial_output_fingerprint, default_model_storage
    )

    # Pretend we are caching the same fingerprint again.
    # Note that it can't happen that we cache a `Cacheable` result twice as we
    # would have replaced the component with a `PrecomputedValueProvider`
    # otherwise.
    temp_cache.cache_output(
        key, None, second_output_fingerprint, default_model_storage
    )

    assert temp_cache.get_cached_output_fingerprint(key) == second_output_fingerprint
def test_caching_cacheable_fails(
    tmp_path: Path,
    caplog: LogCaptureFixture,
    temp_cache: TrainingCache,
    default_model_storage: ModelStorage,
):
    """A failure while persisting an output is logged rather than raised."""
    key = uuid.uuid4().hex

    # `tmp_path` is not a dict and will hence fail to be cached.
    # noinspection PyTypeChecker
    broken_output = TestCacheableOutput(tmp_path)
    output_fingerprint = uuid.uuid4().hex

    with caplog.at_level(logging.ERROR):
        temp_cache.cache_output(
            key, broken_output, output_fingerprint, default_model_storage
        )

    # Exactly one error was logged instead of an exception propagating.
    assert len(caplog.records) == 1

    # The fingerprint mapping survives the failure ...
    assert temp_cache.get_cached_output_fingerprint(key) == output_fingerprint

    # ... but no result could actually be stored.
    miss = temp_cache.get_cached_result(
        output_fingerprint, "some_node", default_model_storage
    )
    assert miss is None