def test_resource_with_model_storage(
    default_model_storage: ModelStorage, tmp_path: Path, temp_cache: TrainingCache
):
    """A cached `Resource` can be restored into a different model storage.

    Persists a JSON file for a resource, caches the resource, restores it into a
    fresh `LocalModelStorage`, and checks that the persisted content survived.
    """
    node_name = "some node"
    resource = Resource(node_name)
    test_filename = "persisted_model.json"
    test_content = {"epochs": 500}

    with default_model_storage.write_to(resource) as temporary_directory:
        rasa.shared.utils.io.dump_obj_as_json_to_file(
            temporary_directory / test_filename, test_content
        )

    test_fingerprint_key = uuid.uuid4().hex
    test_output_fingerprint_key = uuid.uuid4().hex
    temp_cache.cache_output(
        test_fingerprint_key,
        resource,
        test_output_fingerprint_key,
        default_model_storage,
    )

    # Restore into a brand-new, empty model storage.
    new_model_storage_location = tmp_path / "new_model_storage"
    new_model_storage_location.mkdir()
    new_model_storage = LocalModelStorage(new_model_storage_location)
    restored_resource = temp_cache.get_cached_result(
        test_output_fingerprint_key, node_name, new_model_storage
    )

    assert isinstance(restored_resource, Resource)
    # The previous check compared `restored_resource` with itself and could never
    # fail. Compare against the original resource instead. Only the name is
    # compared because other fields (e.g. a per-instance fingerprint) may be set
    # differently on restore — NOTE(review): confirm against `Resource`'s fields.
    assert restored_resource.name == resource.name

    with new_model_storage.read_from(restored_resource) as temporary_directory:
        cached_content = rasa.shared.utils.io.read_json_file(
            temporary_directory / test_filename
        )
        assert cached_content == test_content
def test_fingerprint_component_hit(
    default_model_storage: ModelStorage, temp_cache: TrainingCache
):
    """`FingerprintComponent` reports a cache hit for matching config and inputs.

    The output is cached under the key that `FingerprintComponent` will compute
    for the same component class, config and inputs. Running the node must
    therefore return a hit that carries the previously cached output fingerprint.
    """
    text_to_cache = CacheableText("Cache me!!")
    expected_output_fingerprint = uuid.uuid4().hex

    # Build the same fingerprint key that the `FingerprintComponent` will calculate.
    replaced_config = {"x": 1}
    fingerprinted_inputs = {
        "param_1": FingerprintableText("input_1"),
        "param_2": FingerprintableText("input_2"),
    }
    key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config=replaced_config,
        inputs=fingerprinted_inputs,
    )

    # Store the output under that key.
    temp_cache.cache_output(
        fingerprint_key=key,
        output=text_to_cache,
        output_fingerprint=expected_output_fingerprint,
        model_storage=default_model_storage,
    )

    # The node's config and inputs mirror those used to compute the key above.
    fingerprint_node = GraphNode(
        node_name="fingerprint_node",
        component_class=FingerprintComponent,
        constructor_name="create",
        component_config={
            "config_of_replaced_component": replaced_config,
            "cache": temp_cache,
            "graph_component_class": PrecomputedValueProvider,
        },
        fn_name="run",
        inputs={"param_1": "parent_node_1", "param_2": "parent_node_2"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    parent_results = (
        ("parent_node_1", FingerprintableText("input_1")),
        (
            "parent_node_2",
            FingerprintStatus(is_hit=True, output_fingerprint="input_2"),
        ),
    )
    returned_name, status = fingerprint_node(*parent_results)

    assert returned_name == "fingerprint_node"
    assert status.is_hit is True
    assert status.output_fingerprint == expected_output_fingerprint
    assert status.output_fingerprint == status.fingerprint()
def test_get_cached_result_with_miss(
    temp_cache: TrainingCache, default_model_storage: ModelStorage
):
    """Lookups with unknown keys return `None` even when the cache is non-empty."""
    # Populate the cache so the misses below are not explained by an empty cache.
    temp_cache.cache_output(
        uuid.uuid4().hex,
        TestCacheableOutput({"something to cache": "dasdaasda"}),
        uuid.uuid4().hex,
        default_model_storage,
    )

    unknown_output_fingerprint = uuid.uuid4().hex
    cached_result = temp_cache.get_cached_result(
        unknown_output_fingerprint, "some node", default_model_storage
    )
    assert cached_result is None

    unknown_fingerprint_key = uuid.uuid4().hex
    assert temp_cache.get_cached_output_fingerprint(unknown_fingerprint_key) is None
def test_cache_output(temp_cache: TrainingCache, default_model_storage: ModelStorage):
    """A cached output can be found by its key and restored by its fingerprint."""
    key = uuid.uuid4().hex
    cacheable = TestCacheableOutput({"something to cache": "dasdaasda"})
    fingerprint = uuid.uuid4().hex

    temp_cache.cache_output(key, cacheable, fingerprint, default_model_storage)

    stored_fingerprint = temp_cache.get_cached_output_fingerprint(key)
    assert stored_fingerprint == fingerprint

    restored = temp_cache.get_cached_result(
        fingerprint, "some_node", default_model_storage
    )
    assert restored == cacheable
def test_caching_something_which_is_not_cacheable(
    temp_cache: TrainingCache, default_model_storage: ModelStorage
):
    """Caching a `None` output records its fingerprint but stores no content."""
    key = uuid.uuid4().hex
    fingerprint = uuid.uuid4().hex

    temp_cache.cache_output(key, None, fingerprint, default_model_storage)

    # The key -> output fingerprint mapping is available ...
    assert temp_cache.get_cached_output_fingerprint(key) == fingerprint

    # ... but nothing was stored to disk that could be restored.
    restored = temp_cache.get_cached_result(
        fingerprint, "some_node", default_model_storage
    )
    assert restored is None
def test_restore_cached_output_with_invalid_module(
    temp_cache: TrainingCache,
    default_model_storage: ModelStorage,
    monkeypatch: MonkeyPatch,
    cached_module: Any,
):
    """Restoring yields `None` once `class_from_module_path` is replaced.

    `class_from_module_path` is patched with the `cached_module` fixture.
    NOTE(review): presumably that fixture simulates an invalid/unresolvable
    module — confirm against its definition.
    """
    cacheable = TestCacheableOutput({"something to cache": "dasdaasda"})
    fingerprint = uuid.uuid4().hex
    key = uuid.uuid4().hex
    temp_cache.cache_output(key, cacheable, fingerprint, default_model_storage)

    monkeypatch.setattr(
        rasa.shared.utils.common, "class_from_module_path", cached_module
    )

    restored = temp_cache.get_cached_result(
        fingerprint, "some_node", default_model_storage
    )
    assert restored is None
def test_cache_again(
    temp_cache: TrainingCache,
    default_model_storage: ModelStorage,
    initial_output_fingerprint: Text,
    second_output_fingerprint: Text,
):
    """Caching the same key twice keeps the most recent output fingerprint."""
    key = uuid.uuid4().hex

    # Cache under the same key twice, in order. Only non-`Cacheable` results can
    # be cached twice: for a `Cacheable` result the component would already have
    # been replaced with a `PrecomputedValueProvider`.
    for output_fingerprint in (initial_output_fingerprint, second_output_fingerprint):
        temp_cache.cache_output(key, None, output_fingerprint, default_model_storage)

    assert temp_cache.get_cached_output_fingerprint(key) == second_output_fingerprint
def test_caching_cacheable_fails(
    tmp_path: Path,
    caplog: LogCaptureFixture,
    temp_cache: TrainingCache,
    default_model_storage: ModelStorage,
):
    """A `Cacheable` that fails to persist logs one error and is not restorable.

    The key -> output fingerprint mapping is still recorded.
    """
    key = uuid.uuid4().hex

    # `tmp_path` is not a dict, so caching this output fails.
    # noinspection PyTypeChecker
    broken_output = TestCacheableOutput(tmp_path)
    fingerprint = uuid.uuid4().hex

    with caplog.at_level(logging.ERROR):
        temp_cache.cache_output(key, broken_output, fingerprint, default_model_storage)

    assert len(caplog.records) == 1

    assert temp_cache.get_cached_output_fingerprint(key) == fingerprint

    restored = temp_cache.get_cached_result(
        fingerprint, "some_node", default_model_storage
    )
    assert restored is None