Code example #1
0
def test_training_hook_saves_to_cache(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """The hook should persist a cacheable node output and its fingerprint."""
    # An execution context is required so the hook can resolve the graph
    # component's class from the schema.
    schema = GraphSchema(
        {
            "hello": SchemaNode(
                needs={},
                constructor_name="create",
                fn="run",
                config={},
                uses=CacheableComponent,
            )
        }
    )
    context = ExecutionContext(schema, "1")

    hook = TrainingHook(
        cache=temp_cache,
        model_storage=default_model_storage,
        pruned_schema=context.graph_schema,
    )
    graph_node = GraphNode(
        node_name="hello",
        component_class=CacheableComponent,
        constructor_name="create",
        component_config={},
        fn_name="run",
        inputs={"suffix": "input_node"},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
        hooks=[hook],
    )

    graph_node(("input_node", "Joe"))

    # Recompute the key exactly the way the hook does so we can look up
    # the entry it should have written.
    key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=CacheableComponent,
        config={"prefix": "Hello "},
        inputs={"suffix": "Joe"},
    )

    output_key = temp_cache.get_cached_output_fingerprint(key)
    assert output_key

    restored = temp_cache.get_cached_result(
        output_fingerprint_key=output_key,
        model_storage=default_model_storage,
        node_name="hello",
    )
    assert isinstance(restored, CacheableText)
    assert restored.text == "Hello Joe"
Code example #2
0
File: graph_trainer.py  Project: praneethgb/rasa
    def train(
        self,
        train_schema: GraphSchema,
        predict_schema: GraphSchema,
        domain_path: Path,
        output_filename: Path,
    ) -> GraphRunner:
        """Trains and packages a model and returns the prediction graph runner.

        Args:
            train_schema: The train graph schema.
            predict_schema: The predict graph schema.
            domain_path: The path to the domain file.
            output_filename: The location to save the packaged model.

        Returns:
            A graph runner loaded with the predict schema.
        """
        logger.debug("Starting training.")

        # Fingerprinting prunes nodes whose cached output is still valid.
        pruned_schema = self._fingerprint_and_prune(train_schema)

        training_hooks = [
            TrainingHook(cache=self._cache, model_storage=self._model_storage)
        ]
        training_runner = self._graph_runner_class.create(
            graph_schema=pruned_schema,
            model_storage=self._model_storage,
            execution_context=ExecutionContext(graph_schema=pruned_schema),
            hooks=training_hooks,
        )

        logger.debug("Running the pruned train graph with real node execution.")
        training_runner.run()

        # Package everything needed for inference into the model archive;
        # the domain is loaded fresh from disk just before packaging.
        domain = Domain.from_path(domain_path)
        metadata = self._model_storage.create_model_package(
            output_filename, train_schema, predict_schema, domain
        )

        # Hand back a runner wired up for prediction with the packaged model.
        return self._graph_runner_class.create(
            graph_schema=predict_schema,
            model_storage=self._model_storage,
            execution_context=ExecutionContext(
                graph_schema=predict_schema, model_id=metadata.model_id
            ),
        )
Code example #3
0
def test_training_hook_does_not_cache_cached_component(
    default_model_storage: ModelStorage,
    temp_cache: TrainingCache,
):
    """Outputs of `PrecomputedValueProvider` nodes must not be re-cached."""
    # The hook needs an execution context to look up the component class
    # of the node it observes.
    context = ExecutionContext(
        GraphSchema(
            {
                "hello": SchemaNode(
                    needs={},
                    constructor_name="create",
                    fn="run",
                    config={},
                    uses=PrecomputedValueProvider,
                )
            }
        ),
        "1",
    )
    graph_node = GraphNode(
        node_name="hello",
        component_class=PrecomputedValueProvider,
        constructor_name="create",
        component_config={"output": CacheableText("hi")},
        fn_name="get_value",
        inputs={},
        eager=False,
        model_storage=default_model_storage,
        resource=None,
        execution_context=context,
        hooks=[
            TrainingHook(
                cache=temp_cache,
                model_storage=default_model_storage,
                pruned_schema=context.graph_schema,
            )
        ],
    )

    graph_node(("input_node", "Joe"))

    # Build the identical key the hook would have used for this node.
    key = fingerprinting.calculate_fingerprint_key(
        graph_component_class=PrecomputedValueProvider,
        config={"output": CacheableText("hi")},
        inputs={},
    )

    # A PrecomputedValueProvider only replays previously cached values,
    # so the hook must not write a fresh cache entry for it.
    assert not temp_cache.get_cached_output_fingerprint(key)
Code example #4
0
File: conftest.py  Project: praneethgb/rasa
def default_training_hook(temp_cache: TrainingCache,
                          default_model_storage: ModelStorage) -> TrainingHook:
    """Builds a `TrainingHook` backed by the temporary cache and model storage."""
    hook = TrainingHook(cache=temp_cache, model_storage=default_model_storage)
    return hook
Code example #5
0
    def train(
        self,
        model_configuration: GraphModelConfiguration,
        importer: TrainingDataImporter,
        output_filename: Path,
        force_retraining: bool = False,
        is_finetuning: bool = False,
    ) -> ModelMetadata:
        """Trains and packages a model and returns the model's metadata.

        Args:
            model_configuration: The model configuration (schemas, language, etc.)
            importer: The importer which provides the training data for the training.
            output_filename: The location to save the packaged model.
            force_retraining: If `True` then the cache is skipped and all components
                are retrained.
            is_finetuning: If `True` the fingerprint run is told that components
                are fine-tuned from a previous model rather than trained from
                scratch, and the execution context is flagged accordingly.

        Returns:
            The metadata describing the trained model.
        """
        logger.debug("Starting training.")

        # Retrieve the domain for the model metadata right at the start.
        # This avoids that something during the graph runs mutates it.
        domain = copy.deepcopy(importer.get_domain())

        if force_retraining:
            logger.debug(
                "Skip fingerprint run as a full training of the model was enforced."
            )
            pruned_training_schema = model_configuration.train_schema
        else:
            # The fingerprint run determines which nodes' cached results are
            # still valid so those nodes can be pruned from the training graph.
            fingerprint_run_outputs = self.fingerprint(
                model_configuration.train_schema,
                importer=importer,
                is_finetuning=is_finetuning,
            )
            pruned_training_schema = self._prune_schema(
                model_configuration.train_schema, fingerprint_run_outputs)

        hooks = [
            LoggingHook(pruned_schema=pruned_training_schema),
            TrainingHook(
                cache=self._cache,
                model_storage=self._model_storage,
                pruned_schema=pruned_training_schema,
            ),
        ]

        # NOTE(review): the runner executes the pruned schema while the
        # execution context carries the full (unpruned) train schema —
        # presumably so components can inspect the complete graph; confirm.
        graph_runner = self._graph_runner_class.create(
            graph_schema=pruned_training_schema,
            model_storage=self._model_storage,
            execution_context=ExecutionContext(
                graph_schema=model_configuration.train_schema,
                is_finetuning=is_finetuning,
            ),
            hooks=hooks,
        )

        logger.debug(
            "Running the pruned train graph with real node execution.")

        graph_runner.run(inputs={PLACEHOLDER_IMPORTER: importer})

        return self._model_storage.create_model_package(
            output_filename, model_configuration, domain)