コード例 #1
0
def test_tracker_generator_parameter_interpolation():
    """Check that CLI parameters reach the training tracker provider node.

    The `augmentation_factor` and `debug_plots` values handed to
    `graph_config_for_recipe` must appear verbatim as the config of the
    `training_tracker_provider` node in the train schema.
    """
    cli_parameters = {"augmentation_factor": 0, "debug_plots": True}

    rule_policy_config = rasa.shared.utils.io.read_yaml("""
    version: '2.0'

    policies:
    - name: RulePolicy
    """)

    default_recipe = Recipe.recipe_for_name(DefaultV1Recipe.name)
    # Hand over a copy so the expectation below stays independent of whatever
    # the recipe does with the dict it receives.
    graph_config = default_recipe.graph_config_for_recipe(
        rule_policy_config, dict(cli_parameters)
    )

    tracker_provider = graph_config.train_schema.nodes["training_tracker_provider"]
    assert tracker_provider.config == cli_parameters
コード例 #2
0
def test_epoch_fraction_cli_param_unspecified():
    """Finetune with `finetuning_epoch_fraction` left unset (`None`).

    The stored reference schema is patched before comparison. Every node whose
    config has `finetuning_epoch_fraction` gets it set to 1.0, and if that node
    also has `epochs`, the stored value is doubled. The generated train schema
    must then match the patched reference exactly.
    """
    # TODO: enhance testing of cli instead of imitating expected parsed input
    expected_train_schema = GraphSchema.from_dict(
        rasa.shared.utils.io.read_yaml_file(
            "data/graph_schemas/default_config_finetune_epoch_fraction_schema.yml"
        )
    )

    for expected_node in expected_train_schema.nodes.values():
        node_config = expected_node.config
        if "finetuning_epoch_fraction" not in node_config:
            continue
        node_config["finetuning_epoch_fraction"] = 1.0
        if "epochs" in node_config:
            node_config["epochs"] *= 2

    default_config = rasa.shared.utils.io.read_yaml_file(
        "rasa/engine/recipes/config_files/default_config.yml"
    )

    default_recipe = Recipe.recipe_for_name(DefaultV1Recipe.name)
    actual_train_schema = default_recipe.graph_config_for_recipe(
        default_config, {"finetuning_epoch_fraction": None}, is_finetuning=True
    ).train_schema

    # Compare node by node first so a mismatch names the offending node.
    for name, expected_node in expected_train_schema.nodes.items():
        assert actual_train_schema.nodes[name] == expected_node

    assert actual_train_schema == expected_train_schema
コード例 #3
0
ファイル: test_graph_recipe.py プロジェクト: FGA-GCES/rasa
def test_generate_graphs(
    config_path: Text,
    expected_train_schema_path: Text,
    expected_predict_schema_path: Text,
    training_type: TrainingType,
):
    """Build graphs with the `graph.v1` recipe and compare them to stored schemas.

    Also checks the core and NLU targets and runs graph validation on the
    resulting model configuration. The core target must be `None` for
    NLU-only training. Otherwise it is taken from the config, falling back to
    `"select_prediction"`.
    """

    def load_schema(path: Text):
        # Read a YAML schema file and turn it into a `GraphSchema`.
        return GraphSchema.from_dict(rasa.shared.utils.io.read_yaml_file(path))

    expected_train_schema = load_schema(expected_train_schema_path)
    expected_predict_schema = load_schema(expected_predict_schema_path)

    raw_config = rasa.shared.utils.io.read_yaml_file(config_path)

    graph_recipe = Recipe.recipe_for_name(GraphV1Recipe.name)
    model_config = graph_recipe.graph_config_for_recipe(
        raw_config, {}, training_type=training_type
    )

    assert model_config.train_schema == expected_train_schema
    assert model_config.predict_schema == expected_predict_schema

    expected_core_target = (
        None
        if training_type == TrainingType.NLU
        else raw_config.get("core_target", "select_prediction")
    )
    assert model_config.core_target == expected_core_target

    expected_nlu_target = raw_config.get("nlu_target", "run_RegexMessageHandler")
    assert model_config.nlu_target == expected_nlu_target

    rasa.engine.validation.validate(model_config)
コード例 #4
0
def test_language_returning():
    """The `language` key of the config is exposed on the model configuration."""
    raw_config = rasa.shared.utils.io.read_yaml("""
    language: "xy"
    version: '2.0'

    policies:
    - name: RulePolicy
    """)

    graph_config = Recipe.recipe_for_name(
        DefaultV1Recipe.name
    ).graph_config_for_recipe(raw_config, {})

    assert graph_config.language == "xy"
コード例 #5
0
def test_nlu_config_doesnt_get_overridden(cli_parameters: Dict[Text, Any],
                                          check_node: Text,
                                          expected_config: Dict[Text, Any]):
    """Check that `check_node` keeps `expected_config` despite CLI parameters.

    Builds a finetuning graph for `TrainingType.BOTH` from
    `config_pretrained_embeddings_mitie_diet.yml` with the given
    `cli_parameters`. The node named `check_node` in the resulting train
    schema must carry exactly `expected_config`.
    """
    config = rasa.shared.utils.io.read_yaml_file(
        "data/test_config/config_pretrained_embeddings_mitie_diet.yml")
    recipe = Recipe.recipe_for_name(DefaultV1Recipe.name)
    model_config = recipe.graph_config_for_recipe(
        config,
        cli_parameters,
        training_type=TrainingType.BOTH,
        is_finetuning=True)

    train_schema = model_config.train_schema
    # Fail with a clear message instead of an AttributeError on `None.config`
    # when the node is absent from the generated schema.
    assert check_node in train_schema.nodes, (
        f"Node '{check_node}' is missing from the generated train schema."
    )
    checked_node = train_schema.nodes[check_node]
    assert checked_node.config == expected_config
コード例 #6
0
def test_nlu_training_data_persistence():
    """`persist_nlu_training_data` is forwarded to the NLU training data provider.

    The `nlu_training_data_provider` node must be configured with
    `persist: True` and `language: None` (the config sets no language). The
    node must also be marked as a target.
    """
    raw_config = rasa.shared.utils.io.read_yaml("""
    version: '2.0'

    pipeline:
    - name: KeywordIntentClassifier
    """)

    cli_parameters = {"persist_nlu_training_data": True}
    graph_config = Recipe.recipe_for_name(
        DefaultV1Recipe.name
    ).graph_config_for_recipe(raw_config, cli_parameters)

    provider = graph_config.train_schema.nodes["nlu_training_data_provider"]
    assert provider.config == {"language": None, "persist": True}
    assert provider.is_target
コード例 #7
0
def test_epoch_fraction_cli_param():
    """Finetuning with `finetuning_epoch_fraction=0.5` yields the stored schema."""
    expected_train_schema = GraphSchema.from_dict(
        rasa.shared.utils.io.read_yaml_file(
            "data/graph_schemas/default_config_finetune_epoch_fraction_schema.yml"
        )
    )

    default_config = rasa.shared.utils.io.read_yaml_file(
        "rasa/shared/importers/default_config.yml"
    )

    actual_train_schema = (
        Recipe.recipe_for_name(DefaultV1Recipe.name)
        .graph_config_for_recipe(
            default_config, {"finetuning_epoch_fraction": 0.5}, is_finetuning=True
        )
        .train_schema
    )

    # Node-by-node comparison first gives a precise failure location.
    for name, expected_node in expected_train_schema.nodes.items():
        assert actual_train_schema.nodes[name] == expected_node

    assert actual_train_schema == expected_train_schema
コード例 #8
0
ファイル: test_graph_recipe.py プロジェクト: FGA-GCES/rasa
def test_language_returning():
    """The `graph.v1` recipe exposes the configured `language`.

    The schemas are empty and the targets are placeholders. Only the language
    passthrough is checked.
    """
    raw_config = rasa.shared.utils.io.read_yaml("""
    language: "xy"
    recipe: graph.v1
    core_target: doesnt_validate_or_run
    nlu_target: doesnt_validate_or_run

    train_schema:
      nodes: {}
    predict_schema:
      nodes: {}
    """)

    graph_recipe = Recipe.recipe_for_name(GraphV1Recipe.name)
    assert graph_recipe.graph_config_for_recipe(raw_config, {}).language == "xy"
コード例 #9
0
def test_generate_graphs(
    config_path: Text,
    expected_train_schema_path: Text,
    expected_predict_schema_path: Text,
    training_type: TrainingType,
    is_finetuning: bool,
):
    """Generate graphs with the default recipe and compare them to stored schemas.

    Both schemas are compared node by node, for a precise failure message,
    and then as a whole. The train schema must also pass
    `DefaultV1RecipeValidator` with a default `RasaFileImporter`. Finally, the
    full model configuration must pass graph validation.
    """

    def load_schema(path: Text):
        # Read a YAML schema file and turn it into a `GraphSchema`.
        return GraphSchema.from_dict(rasa.shared.utils.io.read_yaml_file(path))

    def assert_schema_matches(actual, expected):
        for name, expected_node in expected.nodes.items():
            assert actual.nodes[name] == expected_node
        assert actual == expected

    expected_train_schema = load_schema(expected_train_schema_path)
    expected_predict_schema = load_schema(expected_predict_schema_path)

    raw_config = rasa.shared.utils.io.read_yaml_file(config_path)

    default_recipe = Recipe.recipe_for_name(DefaultV1Recipe.name)
    model_config = default_recipe.graph_config_for_recipe(
        raw_config,
        {},
        training_type=training_type,
        is_finetuning=is_finetuning,
    )

    train_schema = model_config.train_schema
    assert_schema_matches(train_schema, expected_train_schema)

    # Must not raise for the generated train schema.
    DefaultV1RecipeValidator(train_schema).validate(RasaFileImporter())

    assert_schema_matches(model_config.predict_schema, expected_predict_schema)

    rasa.engine.validation.validate(model_config)
コード例 #10
0
def test_num_threads_interpolation():
    """A `num_threads` CLI parameter reaches the matching training nodes.

    The stored MITIE reference schemas are patched so that every `train_*`
    node whose `uses` is (a subclass of) `SklearnIntentClassifier`,
    `MitieEntityExtractor` or `MitieIntentClassifier` expects
    `num_threads: 20`. The generated train and predict schemas must then
    match the references exactly.
    """
    threaded_components = (
        SklearnIntentClassifier,
        MitieEntityExtractor,
        MitieIntentClassifier,
    )

    expected_train_schema = GraphSchema.from_dict(
        rasa.shared.utils.io.read_yaml_file(
            "data/graph_schemas/config_pretrained_embeddings_mitie_train_schema.yml"
        )
    )
    expected_predict_schema = GraphSchema.from_dict(
        rasa.shared.utils.io.read_yaml_file(
            "data/graph_schemas/config_pretrained_embeddings_mitie_predict_schema.yml"
        )
    )

    for node_name, node in expected_train_schema.nodes.items():
        uses_threads = issubclass(node.uses, threaded_components)
        if uses_threads and node_name.startswith("train_"):
            node.config["num_threads"] = 20

    mitie_config = rasa.shared.utils.io.read_yaml_file(
        "data/test_config/config_pretrained_embeddings_mitie.yml"
    )
    graph_config = Recipe.recipe_for_name(
        DefaultV1Recipe.name
    ).graph_config_for_recipe(mitie_config, {"num_threads": 20})

    def assert_schema_matches(actual, expected):
        for name, expected_node in expected.nodes.items():
            assert actual.nodes[name] == expected_node
        assert actual == expected

    assert_schema_matches(graph_config.train_schema, expected_train_schema)
    assert_schema_matches(graph_config.predict_schema, expected_predict_schema)
コード例 #11
0
def test_recipe_for_name():
    """`default.v1` resolves to a `DefaultV1Recipe` instance."""
    assert isinstance(Recipe.recipe_for_name("default.v1"), DefaultV1Recipe)
コード例 #12
0
ファイル: test_recipe.py プロジェクト: zoovu/rasa
def test_invalid_recipe():
    """An unknown recipe name raises `InvalidRecipeException`."""
    unknown_name = "dalksldkas"
    with pytest.raises(InvalidRecipeException):
        Recipe.recipe_for_name(unknown_name)
コード例 #13
0
ファイル: test_recipe.py プロジェクト: zoovu/rasa
def test_recipe_is_none():
    """A `None` recipe name emits a `FutureWarning` and yields `DefaultV1Recipe`."""
    with pytest.warns(FutureWarning):
        fallback_recipe = Recipe.recipe_for_name(None)

    assert isinstance(fallback_recipe, DefaultV1Recipe)
コード例 #14
0
ファイル: model_training.py プロジェクト: FGA-GCES/rasa
def _train_graph(
    file_importer: TrainingDataImporter,
    training_type: TrainingType,
    output_path: Text,
    fixed_model_name: Text,
    model_to_finetune: Optional[Text] = None,
    force_full_training: bool = False,
    dry_run: bool = False,
    **kwargs: Any,
) -> TrainingResult:
    """Train a model by building and running the configured recipe's graph.

    Args:
        file_importer: Provides the model config (and the config file used
            for auto-configuration) as well as the training data.
        training_type: Which parts of the model to train (e.g. NLU only or
            both); forwarded to the recipe and used for the model name.
        output_path: Directory in which the trained model is saved.
        fixed_model_name: Passed to `_determine_model_name` together with
            `training_type` to name the saved model.
        model_to_finetune: Path of a previous model to finetune; resolved via
            `rasa.model.get_model_for_finetuning`. A falsy value disables
            finetuning.
        force_full_training: Forwarded to the trainer as `force_retraining`
            and to the dry-run result.
        dry_run: If `True`, only fingerprint the train schema and return
            the `_dry_run_result`; no training is run.
        **kwargs: Forwarded to the recipe as CLI parameters for the graph
            configuration.

    Returns:
        The dry-run result when `dry_run` is set, otherwise
        `TrainingResult(str(full_model_path), 0)`.
    """
    if model_to_finetune:
        model_to_finetune = rasa.model.get_model_for_finetuning(
            model_to_finetune)
        if not model_to_finetune:
            rasa.shared.utils.cli.print_error_and_exit(
                f"No model for finetuning found. Please make sure to either "
                f"specify a path to a previous model or to have a finetunable "
                f"model within the directory '{output_path}'.")

        rasa.shared.utils.common.mark_as_experimental_feature(
            "Incremental Training feature")

    # Use truthiness rather than `is not None`: a falsy value such as "" skips
    # the lookup above, so it must not switch the trainer into finetuning mode
    # with an unresolved model path.
    is_finetuning = bool(model_to_finetune)

    config = file_importer.get_config()
    # `recipe_for_name` receives `None` when the config has no `recipe` key.
    recipe = Recipe.recipe_for_name(config.get("recipe"))
    config, _missing_keys, _configured_keys = recipe.auto_configure(
        file_importer.get_config_file_for_auto_config(),
        config,
        training_type,
    )
    model_configuration = recipe.graph_config_for_recipe(
        config,
        kwargs,
        training_type=training_type,
        is_finetuning=is_finetuning,
    )
    rasa.engine.validation.validate(model_configuration)

    # Model storage is created inside a temporary directory. The trainer is
    # given `full_model_path` as the location to save the trained model.
    with tempfile.TemporaryDirectory() as temp_model_dir:
        model_storage = _create_model_storage(is_finetuning, model_to_finetune,
                                              Path(temp_model_dir))
        cache = LocalTrainingCache()
        trainer = GraphTrainer(model_storage, cache, DaskGraphRunner)

        if dry_run:
            fingerprint_status = trainer.fingerprint(
                model_configuration.train_schema, file_importer)
            return _dry_run_result(fingerprint_status, force_full_training)

        model_name = _determine_model_name(fixed_model_name, training_type)
        full_model_path = Path(output_path, model_name)

        with telemetry.track_model_training(
                file_importer, model_type=training_type.model_type):
            trainer.train(
                model_configuration,
                file_importer,
                full_model_path,
                force_retraining=force_full_training,
                is_finetuning=is_finetuning,
            )
            rasa.shared.utils.cli.print_success(
                f"Your Rasa model is trained and saved at '{full_model_path}'."
            )

        return TrainingResult(str(full_model_path), 0)
コード例 #15
0
ファイル: test_graph_recipe.py プロジェクト: FGA-GCES/rasa
def test_recipe_for_name():
    """`graph.v1` resolves to a `GraphV1Recipe` instance."""
    assert isinstance(Recipe.recipe_for_name("graph.v1"), GraphV1Recipe)