def test_fn_exception(default_model_storage: ModelStorage):
    """Check that an error raised by the component function surfaces as a
    `GraphComponentException` when the node is called."""

    # Minimal component whose `run` always raises.
    class BadFn(GraphComponent):
        @classmethod
        def create(
            cls,
            config: Dict[Text, Any],
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
        ) -> BadFn:
            return cls()

        def run(self) -> None:
            raise ValueError("Oh no!")

    context = ExecutionContext(GraphSchema({}), "some_id")
    failing_node = GraphNode(
        node_name="bad_fn",
        component_class=BadFn,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
    )

    # The `ValueError` from `run` must not escape unwrapped.
    with pytest.raises(GraphComponentException):
        failing_node()
def test_loading_from_resource_eager(default_model_storage: ModelStorage):
    """Check that an eager node built via `load` returns the value that was
    written to its resource beforehand."""
    expected_value = {"test": "test value"}
    stored_resource = Resource("previous resource")

    # Pretend resource persisted itself before
    with default_model_storage.write_to(stored_resource) as storage_dir:
        target_file = storage_dir / "test.json"
        rasa.shared.utils.io.dump_obj_as_json_to_file(target_file, expected_value)

    expected_name = "some_name"
    context = ExecutionContext(GraphSchema({}), "123")
    node = GraphNode(
        node_name=expected_name,
        component_class=PersistableTestComponent,
        constructor_name="load",
        component_config={},
        fn_name="run_inference",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        # The `GraphComponent` should load from this resource
        resource=stored_resource,
        execution_context=context,
    )

    returned_name, loaded_value = node()

    assert returned_name == expected_name
    assert loaded_value == expected_value
def test_writing_to_resource_during_training(default_model_storage: ModelStorage):
    """Check that running `train` returns a resource named after the node and
    that both configured values can be read back from model storage."""
    node_name = "some_name"
    top_level_value = {"test dir": "test value dir"}
    sub_directory_value = {"test": "test value sub dir"}

    training_config = {
        "test_value": top_level_value,
        "test_value_for_sub_directory": sub_directory_value,
    }
    context = ExecutionContext(GraphSchema({}), "123")
    node = GraphNode(
        node_name=node_name,
        component_class=PersistableTestComponent,
        constructor_name="create",
        component_config=training_config,
        fn_name="train",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
    )

    _, trained_resource = node()

    assert trained_resource == Resource(node_name)

    with default_model_storage.read_from(trained_resource) as storage_dir:
        top_level_file = storage_dir / "test.json"
        assert rasa.shared.utils.io.read_json_file(top_level_file) == top_level_value

        nested_file = storage_dir / "sub_dir" / "test.json"
        assert rasa.shared.utils.io.read_json_file(nested_file) == sub_directory_value
def test_training_hook_saves_to_cache(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """Run a node with a `TrainingHook` attached and check that its output can
    afterwards be retrieved from the cache."""
    # We need an execution context so the hook can determine the class of the graph
    # component
    hello_schema_node = SchemaNode(
        needs={},
        constructor_name="create",
        fn="run",
        config={},
        uses=CacheableComponent,
    )
    execution_context = ExecutionContext(GraphSchema({"hello": hello_schema_node}), "1")

    training_hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=execution_context.graph_schema,
    )
    node = GraphNode(
        node_name="hello",
        component_class=CacheableComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={"suffix": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[training_hook],
    )

    node(("input_node", "Joe"))

    # This is the same key that the hook will generate
    expected_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=CacheableComponent,
        config={"prefix": "Hello "},
        inputs={"suffix": "Joe"},
    )

    output_fingerprint_key = temp_cache.get_cached_output_fingerprint(expected_key)
    assert output_fingerprint_key

    restored = temp_cache.get_cached_result(
        output_fingerprint_key=output_fingerprint_key,
        model_storage=default_model_storage,
        node_name="hello",
    )
    assert isinstance(restored, CacheableText)
    assert restored.text == "Hello Joe"
def test_fingerprint_component_hit(
    default_model_storage: ModelStorage, temp_cache: TrainingCache
):
    """Pre-populate the cache under a matching fingerprint key and check that
    the fingerprint node reports a hit carrying the cached output fingerprint."""
    replaced_component_config = {"x": 1}
    expected_output_fingerprint = uuid.uuid4().hex
    output_to_cache = CacheableText("Cache me!!")

    # We generate a fingerprint key that will match the one generated by the
    # `FingerprintComponent`.
    matching_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config=replaced_component_config,
        inputs={
            "param_1": FingerprintableText("input_1"),
            "param_2": FingerprintableText("input_2"),
        },
    )

    # We cache the output using this fingerprint key.
    temp_cache.cache_output(
        fingerprint_key=matching_key,
        output=output_to_cache,
        output_fingerprint=expected_output_fingerprint,
        model_storage=default_model_storage,
    )

    # The node inputs and config match what we used to generate the fingerprint key.
    fingerprint_config = {
        "config_of_replaced_component": replaced_component_config,
        "cache": temp_cache,
        "graph_component_class": PrecomputedValueProvider,
    }
    node = GraphNode(
        node_name="fingerprint_node",
        component_class=FingerprintComponent,
        constructor_name="create",
        component_config=fingerprint_config,
        fn_name="run",
        inputs={"param_1": "parent_node_1", "param_2": "parent_node_2"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    first_parent_output = ("parent_node_1", FingerprintableText("input_1"))
    second_parent_output = (
        "parent_node_2",
        FingerprintStatus(is_hit=True, output_fingerprint="input_2"),
    )
    node_name, status = node(first_parent_output, second_parent_output)

    assert node_name == "fingerprint_node"
    assert status.is_hit is True
    assert status.output_fingerprint == expected_output_fingerprint
    assert status.output_fingerprint == status.fingerprint()
def _instantiate_nodes(
    graph_schema: GraphSchema,
    model_storage: ModelStorage,
    execution_context: ExecutionContext,
    hooks: Optional[List[GraphNodeHook]] = None,
) -> Dict[Text, GraphNode]:
    """Create a `GraphNode` for every entry of `graph_schema.nodes`.

    Args:
        graph_schema: Schema whose `nodes` mapping is iterated.
        model_storage: Passed through to `GraphNode.from_schema_node`.
        execution_context: Passed through to `GraphNode.from_schema_node`.
        hooks: Optional hooks passed through to `GraphNode.from_schema_node`.

    Returns:
        Mapping from node name to the node built for it.
    """
    instantiated: Dict[Text, GraphNode] = {}
    for name, node_definition in graph_schema.nodes.items():
        instantiated[name] = GraphNode.from_schema_node(
            name, node_definition, model_storage, execution_context, hooks
        )
    return instantiated
def test_training_hook_does_not_cache_cached_component(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """Run a `PrecomputedValueProvider` node with a `TrainingHook` attached and
    check that nothing was cached for it."""
    # We need an execution context so the hook can determine the class of the graph
    # component
    hello_schema_node = SchemaNode(
        needs={},
        constructor_name="create",
        fn="run",
        config={},
        uses=PrecomputedValueProvider,
    )
    execution_context = ExecutionContext(GraphSchema({"hello": hello_schema_node}), "1")

    training_hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=execution_context.graph_schema,
    )
    node = GraphNode(
        node_name="hello",
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        component_config={"output": CacheableText("hi")},
        fn_name="get_value",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[training_hook],
    )

    node(("input_node", "Joe"))

    # This is the same key that the hook will generate
    expected_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config={"output": CacheableText("hi")},
        inputs={},
    )

    # The hook should not cache the output of a PrecomputedValueProvider
    cached_fingerprint = temp_cache.get_cached_output_fingerprint(expected_key)
    assert not cached_fingerprint
def test_config_with_nested_dict_override(default_model_storage: ModelStorage):
    """Check that overriding one key of a nested default config dict keeps the
    other default keys in the node's resulting component config."""

    # Component whose default config contains a nested dictionary.
    class ComponentWithNestedDictConfig(GraphComponent):
        @staticmethod
        def get_default_config() -> Dict[Text, Any]:
            return {"nested-dict": {"key1": "value1", "key2": "value2"}}

        @classmethod
        def create(
            cls,
            config: Dict,
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
            **kwargs: Any,
        ) -> ComponentWithNestedDictConfig:
            return cls()

        def run(self) -> None:
            return None

    override = {"nested-dict": {"key2": "override-value2"}}
    node = GraphNode(
        node_name="nested_dict_config",
        component_class=ComponentWithNestedDictConfig,
        constructor_name="create",
        component_config=override,
        fn_name="run",
        inputs={},
        eager=True,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "123"),
    )

    expected_config = {"nested-dict": {"key1": "value1", "key2": "override-value2"}}

    for key, expected_value in expected_config.items():
        assert key in node._component_config
        if not isinstance(expected_value, dict):
            continue
        # Compare nested dictionaries entry by entry.
        for nested_key, nested_expected in expected_value.items():
            assert nested_key in node._component_config[key]
            assert node._component_config[key][nested_key] == nested_expected
def test_can_use_alternate_constructor(default_model_storage: ModelStorage):
    """Check that a node can build its component through a constructor other
    than `create` (here `create_with_2`)."""
    context = ExecutionContext(GraphSchema({}), "1")
    provider_node = GraphNode(
        node_name="provide",
        component_class=ProvideX,
        constructor_name="create_with_2",
        component_config={},
        fn_name="provide",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
    )

    assert provider_node() == ("provide", 2)
def test_exception_handling_for_on_before_hook(
    on_before: Callable,
    on_after: Callable,
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
):
    """Check that calling a node whose hook fails raises `GraphComponentException`.

    Args:
        on_before: Callable invoked from the hook's `on_before_node`.
        on_after: Callable invoked from the hook's `on_after_node`.
        default_model_storage: Model storage handed to the node.
        default_execution_context: Execution context handed to the node.

    NOTE(review): the source of `on_before` / `on_after` (presumably a
    parametrization in which at least one of them raises) is not visible
    here - confirm against the decorator.
    """
    schema_node = SchemaNode(
        needs={},
        uses=ProvideX,
        fn="provide",
        constructor_name="create",
        config={},
    )

    class MyHook(GraphNodeHook):
        # Each callable is wired to the hook method it is named after, so a
        # failing `on_before` exercises `on_before_node` and a failing
        # `on_after` exercises `on_after_node`.
        def on_before_node(
            self,
            node_name: Text,
            execution_context: ExecutionContext,
            config: Dict[Text, Any],
            received_inputs: Dict[Text, Any],
        ) -> Dict:
            on_before()
            return {}

        def on_after_node(
            self,
            node_name: Text,
            execution_context: ExecutionContext,
            config: Dict[Text, Any],
            output: Any,
            input_hook_data: Dict,
        ) -> None:
            on_after()

    node = GraphNode.from_schema_node(
        "some_node",
        schema_node,
        default_model_storage,
        default_execution_context,
        hooks=[MyHook()],
    )

    with pytest.raises(GraphComponentException):
        node()
def test_component_config(
    x: Optional[int], output: int, default_model_storage: ModelStorage
):
    """Check that the `x` config value reaches the component.

    Args:
        x: Value to put under the `"x"` config key; `None` means no config is
            passed, so the component runs with whatever defaults it has.
        output: Expected result of running `subtract_x` with the input `5`.
        default_model_storage: Model storage handed to the node.
    """
    # Compare against `None` explicitly: a plain truthiness check would also
    # drop an explicit `x=0` and silently test the default config instead.
    component_config = {"x": x} if x is not None else {}

    node = GraphNode(
        node_name="subtract",
        component_class=SubtractByX,
        constructor_name="create",
        component_config=component_config,
        fn_name="subtract_x",
        inputs={"i": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    result = node(("input_node", 5))

    assert result == ("subtract", output)
def test_eager_and_not_eager(eager: bool, default_model_storage: ModelStorage):
    """Check when the component gets constructed: at node creation if `eager`
    is set, otherwise not before the node is called. In both cases `create`
    is invoked exactly once and `run` only upon calling the node."""
    create_mock = Mock()
    run_mock = Mock()

    # Component that records calls to its constructor and its run function.
    class SpyComponent(GraphComponent):
        @classmethod
        def create(
            cls,
            config: Dict,
            model_storage: ModelStorage,
            resource: Resource,
            execution_context: ExecutionContext,
        ) -> SpyComponent:
            create_mock()
            return cls()

        def run(self):
            return run_mock()

    context = ExecutionContext(GraphSchema({}), "1")
    spy_node = GraphNode(
        node_name="spy_node",
        component_class=SpyComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={},
        eager=eager,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
    )

    # Construction has happened already only for eager nodes.
    assert create_mock.called == bool(eager)
    assert not run_mock.called

    spy_node()

    assert create_mock.call_count == 1
    assert run_mock.called
def test_calling_component(default_model_storage: ModelStorage):
    """Check that outputs of parent nodes are mapped onto the parameters of the
    component function and that the node returns its name with the result."""
    input_mapping = {"i1": "input_node1", "i2": "input_node2"}
    context = ExecutionContext(GraphSchema({}), "1")
    adder = GraphNode(
        node_name="add_node",
        component_class=AddInputs,
        constructor_name="create",
        component_config={},
        fn_name="add",
        inputs=input_mapping,
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
    )

    first_input = ("input_node1", 3)
    second_input = ("input_node2", 4)

    assert adder(first_input, second_input) == ("add_node", 7)
def test_execution_context(default_model_storage: ModelStorage):
    """Check that the context returned by the component keeps the model id it
    had at node creation and carries the node's name, even though the original
    context object is modified afterwards."""
    original_context = ExecutionContext(GraphSchema({}), "some_id")

    node = GraphNode(
        node_name="execution_context_aware",
        component_class=ExecutionContextAware,
        constructor_name="create",
        component_config={},
        fn_name="get_execution_context",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=original_context,
    )

    # Change the caller's context after the node was built.
    original_context.model_id = "a_new_id"

    node_output = node()
    component_context = node_output[1]

    assert component_context.model_id == "some_id"
    assert component_context.node_name == "execution_context_aware"
def test_cached_component_returns_value_from_cache(
    default_model_storage: ModelStorage,
):
    """Check that a `PrecomputedValueProvider` node returns the output it was
    configured with."""
    precomputed = CacheableText("Cache me!!")
    context = ExecutionContext(GraphSchema({}), "1")

    node = GraphNode(
        node_name="cached",
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        component_config={"output": precomputed},
        fn_name="get_value",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
    )

    returned_name, returned_value = node()

    assert returned_name == "cached"
    assert returned_value.text == "Cache me!!"
def test_fingerprint_component_miss(
    default_model_storage: ModelStorage, temp_cache: TrainingCache
):
    """Run the fingerprint node against an empty cache and check that the
    returned status reports a miss without an output fingerprint."""
    replaced_component_config = {"x": 1}
    fingerprint_config = {
        "config_of_replaced_component": replaced_component_config,
        "cache": temp_cache,
        "graph_component_class": PrecomputedValueProvider,
    }

    node = GraphNode(
        node_name="fingerprint_node",
        component_class=FingerprintComponent,
        constructor_name="create",
        component_config=fingerprint_config,
        fn_name="run",
        inputs={"param_1": "parent_node_1", "param_2": "parent_node_2"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=ExecutionContext(GraphSchema({}), "1"),
    )

    first_parent_output = ("parent_node_1", FingerprintableText("input_1"))
    second_parent_output = (
        "parent_node_2",
        FingerprintStatus(is_hit=True, output_fingerprint="input_2"),
    )
    node_name, status = node(first_parent_output, second_parent_output)

    # As we didn't add anything to the cache, it cannot be a hit.
    assert node_name == "fingerprint_node"
    assert status.is_hit is False
    assert status.output_fingerprint is None
    assert status.fingerprint() != status.output_fingerprint
    # Intentional: two consecutive `fingerprint()` calls on a miss are expected
    # to give different values.
    assert status.fingerprint() != status.fingerprint()