def test_training_hook_saves_to_cache(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """A cacheable node output run through `TrainingHook` ends up in the cache."""
    # The hook resolves the component class of the running node via the
    # execution context, so a schema containing the node must be supplied.
    schema = GraphSchema(
        {
            "hello": SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=CacheableComponent,
            )
        }
    )
    execution_context = ExecutionContext(schema, "1")

    training_hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=execution_context.graph_schema,
    )
    node = GraphNode(
        node_name="hello",
        component_class=CacheableComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={"suffix": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[training_hook],
    )

    node(("input_node", "Joe"))

    # Recompute the fingerprint key exactly the way the hook does internally.
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=CacheableComponent,
        config={"prefix": "Hello "},
        inputs={"suffix": "Joe"},
    )

    output_fingerprint_key = temp_cache.get_cached_output_fingerprint(fingerprint_key)
    assert output_fingerprint_key

    cached_result = temp_cache.get_cached_result(
        output_fingerprint_key=output_fingerprint_key,
        model_storage=default_model_storage,
        node_name="hello",
    )
    assert isinstance(cached_result, CacheableText)
    assert cached_result.text == "Hello Joe"
def train(
    self,
    train_schema: GraphSchema,
    predict_schema: GraphSchema,
    domain_path: Path,
    output_filename: Path,
) -> GraphRunner:
    """Trains and packages a model and returns the prediction graph runner.

    Args:
        train_schema: The train graph schema.
        predict_schema: The predict graph schema.
        domain_path: The path to the domain file.
        output_filename: The location to save the packaged model.

    Returns:
        A graph runner loaded with the predict schema.
    """
    logger.debug("Starting training.")

    # Fingerprinting lets cached nodes be dropped from the schema before the
    # (expensive) real training run.
    pruned_schema = self._fingerprint_and_prune(train_schema)

    train_runner = self._graph_runner_class.create(
        graph_schema=pruned_schema,
        model_storage=self._model_storage,
        execution_context=ExecutionContext(graph_schema=pruned_schema),
        hooks=[TrainingHook(cache=self._cache, model_storage=self._model_storage)],
    )

    logger.debug("Running the pruned train graph with real node execution.")
    train_runner.run()

    domain = Domain.from_path(domain_path)
    metadata = self._model_storage.create_model_package(
        output_filename, train_schema, predict_schema, domain
    )

    # Hand back a runner ready for inference with the freshly packaged model.
    return self._graph_runner_class.create(
        graph_schema=predict_schema,
        model_storage=self._model_storage,
        execution_context=ExecutionContext(
            graph_schema=predict_schema, model_id=metadata.model_id
        ),
    )
def test_training_hook_does_not_cache_cached_component(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """Outputs of a `PrecomputedValueProvider` must not be re-cached by the hook."""
    # The hook resolves the component class of the running node via the
    # execution context, so a schema containing the node must be supplied.
    schema = GraphSchema(
        {
            "hello": SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=PrecomputedValueProvider,
            )
        }
    )
    execution_context = ExecutionContext(schema, "1")

    training_hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=execution_context.graph_schema,
    )
    node = GraphNode(
        node_name="hello",
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        component_config={"output": CacheableText("hi")},
        fn_name="get_value",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=execution_context,
        hooks=[training_hook],
    )

    node(("input_node", "Joe"))

    # Recompute the fingerprint key exactly the way the hook does internally.
    fingerprint_key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config={"output": CacheableText("hi")},
        inputs={},
    )

    # The hook should not cache the output of a PrecomputedValueProvider
    assert not temp_cache.get_cached_output_fingerprint(fingerprint_key)
def default_training_hook(
    temp_cache: TrainingCache, default_model_storage: ModelStorage
) -> TrainingHook:
    """Returns a `TrainingHook` wired to the test cache and model storage."""
    hook = TrainingHook(cache=temp_cache, model_storage=default_model_storage)
    return hook
def train(
    self,
    model_configuration: GraphModelConfiguration,
    importer: TrainingDataImporter,
    output_filename: Path,
    force_retraining: bool = False,
    is_finetuning: bool = False,
) -> ModelMetadata:
    """Trains and packages a model and returns the prediction graph runner.

    Args:
        model_configuration: The model configuration (schemas, language, etc.)
        importer: The importer which provides the training data for the training.
        output_filename: The location to save the packaged model.
        force_retraining: If `True` then the cache is skipped and all components
            are retrained.
        is_finetuning: If `True` the graph runs in finetuning mode.

    Returns:
        The metadata describing the trained model.
    """
    logger.debug("Starting training.")

    # Snapshot the domain before anything runs so that mutations during the
    # graph execution cannot leak into the packaged model metadata.
    domain = copy.deepcopy(importer.get_domain())

    if not force_retraining:
        # A fingerprint run determines which nodes can be served from cache;
        # those are pruned away before the real training run.
        fingerprint_outputs = self.fingerprint(
            model_configuration.train_schema,
            importer=importer,
            is_finetuning=is_finetuning,
        )
        pruned_schema = self._prune_schema(
            model_configuration.train_schema, fingerprint_outputs
        )
    else:
        logger.debug(
            "Skip fingerprint run as a full training of the model was enforced."
        )
        pruned_schema = model_configuration.train_schema

    hooks = [
        LoggingHook(pruned_schema=pruned_schema),
        TrainingHook(
            cache=self._cache,
            model_storage=self._model_storage,
            pruned_schema=pruned_schema,
        ),
    ]
    runner = self._graph_runner_class.create(
        graph_schema=pruned_schema,
        model_storage=self._model_storage,
        execution_context=ExecutionContext(
            graph_schema=model_configuration.train_schema,
            is_finetuning=is_finetuning,
        ),
        hooks=hooks,
    )

    logger.debug("Running the pruned train graph with real node execution.")
    runner.run(inputs={PLACEHOLDER_IMPORTER: importer})

    return self._model_storage.create_model_package(
        output_filename, model_configuration, domain
    )