def test_importer_with_invalid_model_config(tmp_path: Path):
    """A config whose policies are plain strings is rejected on load."""
    config_path = tmp_path / "config.yml"
    rasa.shared.utils.io.write_yaml(
        {"version": "2.0", "policies": ["name"]}, config_path
    )

    with pytest.raises(YamlValidationException):
        TrainingDataImporter.load_from_config(str(config_path))
def test_story_graph_provider_provide(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    config: Dict[Text, Any],
    config_path: Text,
    domain_path: Text,
    stories_path: Text,
):
    """The provider yields the same story graph the importer loads directly.

    `config` overrides the provider's default config and is also passed as
    keyword arguments to `importer.get_stories`.
    """
    provider_config = {**StoryGraphProvider.get_default_config(), **config}
    provider = StoryGraphProvider.create(
        provider_config,
        default_model_storage,
        Resource("xy"),
        default_execution_context,
    )

    importer = TrainingDataImporter.load_from_config(
        config_path, domain_path, [stories_path]
    )

    provided_graph = provider.provide(importer)
    assert isinstance(provided_graph, StoryGraph)

    expected_graph = importer.get_stories(**config)
    assert expected_graph.fingerprint() == provided_graph.fingerprint()
async def test_events_schema(
    monkeypatch: MonkeyPatch, default_agent: Agent, config_path: Text
):
    """Validate every backend telemetry event against its JSON schema.

    Telemetry is enabled in debug mode and `telemetry.print_telemetry_event`
    is replaced by a mock, so each reported event arrives as the first
    positional argument of a mock call. The properties of every captured
    event are validated against the schema of the same name loaded from
    `TELEMETRY_EVENTS_JSON`.
    """
    # this allows us to patch the printing part used in debug mode to collect the
    # reported events
    monkeypatch.setenv("RASA_TELEMETRY_DEBUG", "true")
    monkeypatch.setenv("RASA_TELEMETRY_ENABLED", "true")
    mock = Mock()
    monkeypatch.setattr(telemetry, "print_telemetry_event", mock)

    with open(TELEMETRY_EVENTS_JSON) as f:
        schemas = json.load(f)["events"]
    # Snapshot the tasks alive before tracking starts. This set includes the
    # test's own task, so it is excluded from `pending` below and the test
    # never awaits itself.
    initial = asyncio.all_tasks()
    # Generate all known backend telemetry events, and then use events.json to
    # validate their schema.
    training_data = TrainingDataImporter.load_from_config(config_path)

    with telemetry.track_model_training(training_data, "rasa"):
        await asyncio.sleep(1)

    telemetry.track_telemetry_disabled()
    telemetry.track_data_split(0.5, "nlu")
    telemetry.track_validate_files(True)
    telemetry.track_data_convert("yaml", "nlu")
    telemetry.track_tracker_export(5, TrackerStore(domain=None), EventBroker())
    telemetry.track_interactive_learning_start(True, False)
    telemetry.track_server_start([CmdlineInput()], None, None, 42, True)
    telemetry.track_project_init("tests/")
    telemetry.track_shell_started("nlu")
    telemetry.track_rasa_x_local()
    telemetry.track_visualization()
    telemetry.track_core_model_test(5, True, default_agent)
    telemetry.track_nlu_model_test(TrainingData())

    # Await every task created since the snapshot (presumably ones scheduled
    # by the tracking calls) so all events are reported before counting.
    pending = asyncio.all_tasks() - initial
    await asyncio.gather(*pending)

    # NOTE(review): presumably 2 events from `track_model_training` plus one per
    # remaining tracker call above (13) - update when adding trackers.
    assert mock.call_count == 15

    for args, _ in mock.call_args_list:
        event = args[0]
        # `metrics_id` automatically gets added to all events but is
        # not part of the schema so we need to remove it before validation
        del event["properties"]["metrics_id"]
        jsonschema.validate(
            instance=event["properties"], schema=schemas[event["event"]]
        )
def test_domain_provider_provides_and_persists_domain(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    config_path: Text,
    domain_path: Text,
    domain: Domain,
):
    """The provider returns the training domain, persists it, and reloads it.

    The domain returned for training, the persisted `domain.yml`, and the
    domain returned for inference must all match the `domain` fixture's
    fingerprint.
    """
    resource = Resource("xy")
    train_provider = DomainProvider.create(
        DomainProvider.get_default_config(),
        default_model_storage,
        resource,
        default_execution_context,
    )
    assert isinstance(train_provider, DomainProvider)

    importer = TrainingDataImporter.load_from_config(config_path, domain_path)
    training_domain = train_provider.provide_train(importer)
    assert isinstance(training_domain, Domain)
    assert domain.fingerprint() == training_domain.fingerprint()

    # Exactly one `domain.yml` must have been written into the resource.
    with default_model_storage.read_from(resource) as resource_dir:
        persisted_files = list(resource_dir.glob("**/domain.yml"))
        assert len(persisted_files) == 1
        persisted_file = persisted_files[0]
        assert persisted_file.is_file()
        assert domain.fingerprint() == Domain.from_path(persisted_file).fingerprint()

    inference_provider = DomainProvider.load(
        {}, default_model_storage, resource, default_execution_context
    )
    inference_domain = inference_provider.provide_inference()
    assert isinstance(inference_domain, Domain)
    assert domain.fingerprint() == inference_domain.fingerprint()
def test_example_bot_training_data_raises_only_auto_fill_warning(
    config_file: Text,
    domain_file: Text,
    data_folder: Text,
    raise_slot_warning: bool,
):
    """Loading example bot data emits only slot auto-fill warnings, if any.

    When `raise_slot_warning` is true, exactly two `UserWarning`s are expected
    and both must be the slot auto-fill removal notice. Otherwise no warning of
    any kind may be emitted.
    """
    import warnings

    importer = TrainingDataImporter.load_from_config(
        config_file, domain_file, [data_folder]
    )

    if raise_slot_warning:
        with pytest.warns(UserWarning) as record:
            importer.get_nlu_data()
            importer.get_stories()

        assert len(record) == 2
        assert all(
            "Slot auto-fill has been removed in 3.0 and replaced with "
            "a new explicit mechanism to set slots." in r.message.args[0]
            for r in record
        )
    else:
        # `pytest.warns(None)` is deprecated since pytest 7 and rejected by
        # pytest 8; record every warning explicitly instead ("always" makes
        # sure repeated or filtered warnings are still captured).
        with warnings.catch_warnings(record=True) as record:
            warnings.simplefilter("always")
            importer.get_nlu_data()
            importer.get_stories()

        assert len(record) == 0
async def test_should_not_retrain_core(default_domain_path: Text, tmp_path: Path):
    """A freshly trained core model is not flagged for core retraining."""
    # Don't use `default_stories_file` as checkpoints currently break fingerprinting
    story_path = tmp_path / "simple_story.yml"
    story_path.write_text(
        """
stories:
- story: test_story
  steps:
  - intent: greet
  - action: utter_greet
"""
    )

    model_path = await train_core_async(
        default_domain_path, DEFAULT_STACK_CONFIG, str(story_path), str(tmp_path)
    )

    fingerprint = await model.model_fingerprint(
        TrainingDataImporter.load_from_config(
            DEFAULT_STACK_CONFIG,
            default_domain_path,
            training_data_paths=[str(story_path)],
        )
    )

    retrain_result = model.should_retrain(fingerprint, model_path, tmp_path)
    assert not retrain_result.should_retrain_core()
def interactive(args: argparse.Namespace) -> None:
    """Start an interactive learning session.

    Without `args.model`, an initial model is trained first (core-only when
    `args.core_only` is set), which requires stories to be present. With
    `args.model`, the given model path is resolved and must exist on disk.
    Errors are reported through `print_error_and_exit`.

    Args:
        args: Parsed command line arguments.
    """
    _set_not_required_args(args)
    file_importer = TrainingDataImporter.load_from_config(
        args.config, args.domain, args.data
    )

    if args.model is not None:
        zipped_model = get_provided_model(args.model)
        model_exists = bool(zipped_model) and os.path.exists(zipped_model)
        if not model_exists:
            rasa.shared.utils.cli.print_error_and_exit(
                f"Interactive learning process cannot be started as no initial model was "
                f"found at path '{args.model}'. Use 'rasa train' to train a model."
            )
    else:
        event_loop = asyncio.get_event_loop()
        stories = event_loop.run_until_complete(file_importer.get_stories())
        if not stories or stories.is_empty():
            rasa.shared.utils.cli.print_error_and_exit(
                "Could not run interactive learning without either core data or a model containing core data."
            )

        train_function = train.train_core if args.core_only else train.train
        zipped_model = train_function(args)
        if not zipped_model:
            rasa.shared.utils.cli.print_error_and_exit(
                "Could not train an initial model. Either pass paths "
                "to the relevant training files (`--data`, `--config`, `--domain`), "
                "or use 'rasa train' to train a model."
            )

    if not args.skip_visualization:
        logger.info(f"Loading visualization data from {args.data}.")

    perform_interactive_learning(args, zipped_model, file_importer)
def test_no_warnings_with_default_project(tmp_path: Path):
    """Validating the initial project emits only slot auto-fill warnings."""
    rasa.utils.common.copy_directory(Path("rasa/cli/initial_project"), tmp_path)

    importer = TrainingDataImporter.load_from_config(
        config_path=str(tmp_path / "config.yml"),
        domain_path=str(tmp_path / "domain.yml"),
        training_data_paths=[str(tmp_path / "data")],
    )

    config, _, _ = DefaultV1Recipe.auto_configure(
        importer.get_config_file_for_auto_config(),
        importer.get_config(),
        TrainingType.END_TO_END,
    )
    recipe_config = DefaultV1Recipe().graph_config_for_recipe(
        config, cli_parameters={}, training_type=TrainingType.END_TO_END
    )
    validator = DefaultV1RecipeValidator(recipe_config.train_schema)

    with pytest.warns(
        UserWarning, match="Slot auto-fill has been removed in 3.0"
    ) as records:
        validator.validate(importer)

    # Every recorded warning must be the auto-fill removal notice.
    assert all(
        warning.message.args[0].startswith("Slot auto-fill has been removed")
        for warning in records.list
    )
async def test_events_schema(monkeypatch: MonkeyPatch):
    """Validate the model-training and telemetry-disabled events' schemas.

    Telemetry is enabled in debug mode and `telemetry.print_telemetry_event`
    is replaced by a mock, so each reported event arrives as the first
    positional argument of a mock call; each one is validated against the
    schema of the same name from `TELEMETRY_EVENTS_JSON`.
    """
    # this allows us to patch the printing part used in debug mode to collect the
    # reported events
    monkeypatch.setenv("RASA_TELEMETRY_DEBUG", "true")
    monkeypatch.setenv("RASA_TELEMETRY_ENABLED", "true")
    mock = Mock()
    monkeypatch.setattr(telemetry, "print_telemetry_event", mock)

    with open(TELEMETRY_EVENTS_JSON) as f:
        schemas = json.load(f)["events"]

    # `asyncio.Task.all_tasks()` was removed in Python 3.9; the module-level
    # `asyncio.all_tasks()` is the supported replacement. The snapshot contains
    # the test's own task, so it is excluded from `pending` below.
    initial = asyncio.all_tasks()

    # Generate all known backend telemetry events, and then use events.json to
    # validate their schema.
    training_data = TrainingDataImporter.load_from_config(DEFAULT_CONFIG_PATH)

    async with telemetry.track_model_training(training_data, "rasa"):
        await asyncio.sleep(1)

    await telemetry.track_telemetry_disabled()

    # Await every task created since the snapshot so all events are reported.
    pending = asyncio.all_tasks() - initial
    await asyncio.gather(*pending)

    assert mock.call_count == 3

    # Tuple unpacking of `call` objects works on every Python version, unlike
    # the `call.args` attribute which only exists from Python 3.8 on.
    for args, _ in mock.call_args_list:
        event = args[0]
        # `metrics_id` automatically gets added to all events but is
        # not part of the schema so we need to remove it before validation
        del event["properties"]["metrics_id"]
        jsonschema.validate(
            instance=event["properties"], schema=schemas[event["event"]]
        )
async def train_async(
    domain: Union[Domain, Text],
    config: Text,
    training_files: Optional[Union[Text, List[Text]]],
    output: Text = DEFAULT_MODELS_PATH,
    dry_run: bool = False,
    force_training: bool = False,
    fixed_model_name: Optional[Text] = None,
    persist_nlu_training_data: bool = False,
    core_additional_arguments: Optional[Dict] = None,
    nlu_additional_arguments: Optional[Dict] = None,
) -> TrainingResult:
    """Trains a Rasa model (Core and NLU).

    Args:
        domain: Path to the domain file (the signature also admits a
            `Domain`); passed unchanged to the training data importer.
        config: Path to the config for Core and NLU.
        training_files: Paths to the training data for Core and NLU.
        output: Output path.
        dry_run: If `True` then no training will be done, and the information
            about whether the training needs to be done will be printed.
        force_training: If `True` retrain model even if data has not changed.
        fixed_model_name: Name of model to be stored.
        persist_nlu_training_data: `True` if the NLU training data should be
            persisted with the model.
        core_additional_arguments: Additional training parameters for core
            training.
        nlu_additional_arguments: Additional training parameters forwarded to
            training method of each NLU component.

    Returns:
        An instance of `TrainingResult`. If the loaded domain is empty, the
        result wraps whatever `handle_domain_if_not_exists` returns.
    """
    file_importer = TrainingDataImporter.load_from_config(
        config, domain, training_files
    )

    # Only one context is managed here, so a plain `with` replaces ExitStack.
    with TempDirectoryPath(tempfile.mkdtemp()) as train_path:
        # Use a separate name instead of shadowing the `domain` argument.
        loaded_domain = await file_importer.get_domain()

        if loaded_domain.is_empty():
            nlu_model = await handle_domain_if_not_exists(
                file_importer, output, fixed_model_name
            )
            return TrainingResult(model=nlu_model)

        return await _train_async_internal(
            file_importer,
            train_path,
            output,
            dry_run,
            force_training,
            fixed_model_name,
            persist_nlu_training_data,
            core_additional_arguments=core_additional_arguments,
            nlu_additional_arguments=nlu_additional_arguments,
        )
def test_load_from_config(tmpdir: Path):
    """`load_from_config` wraps the configured importer in E2E/responses layers."""
    config_path = str(tmpdir / "config.yml")
    importer_config = {"importers": [{"name": "MultiProjectImporter"}]}
    rasa.shared.utils.io.write_yaml(importer_config, config_path)

    importer = TrainingDataImporter.load_from_config(config_path)

    # Expected nesting: E2EImporter -> ResponsesSyncImporter, whose `_importer`
    # holds the configured importers in `_importers`.
    assert isinstance(importer, E2EImporter)
    responses_importer = importer.importer
    assert isinstance(responses_importer, ResponsesSyncImporter)
    first_configured = responses_importer._importer._importers[0]
    assert isinstance(first_configured, MultiProjectImporter)
async def test_example_bot_training_data_not_raises(
    config_file: Text, domain_file: Text, data_folder: Text
):
    """Loading an example bot's NLU data and stories emits no warnings."""
    import warnings

    importer = TrainingDataImporter.load_from_config(
        config_file, domain_file, [data_folder]
    )

    # `pytest.warns(None)` is deprecated since pytest 7 and rejected by
    # pytest 8; record every warning explicitly instead.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        await importer.get_nlu_data()
        await importer.get_stories()

    assert not len(record)
def test_importer_with_invalid_model_config(tmp_path: Path):
    """An invalid model config raises `YamlValidationException`.

    The exception may come either from loading the importer or from
    auto-configuring the recipe with its config.
    """
    config_path = tmp_path / "config.yml"
    rasa.shared.utils.io.write_yaml(
        {"version": "2.0", "policies": ["name"]}, config_path
    )

    with pytest.raises(YamlValidationException):
        importer = TrainingDataImporter.load_from_config(str(config_path))
        DefaultV1Recipe.auto_configure(
            importer.get_config_file_for_auto_config(),
            importer.get_config(),
            TrainingType.END_TO_END,
        )
async def test_example_bot_training_on_initial_project(tmp_path: Path):
    """Loading the scaffolded initial project's data emits no warnings."""
    import warnings

    # we need to test this one separately, as we can't test it in place
    # configuration suggestions would otherwise change the initial file
    scaffold.create_initial_project(str(tmp_path))

    importer = TrainingDataImporter.load_from_config(
        str(tmp_path / "config.yml"),
        str(tmp_path / "domain.yml"),
        str(tmp_path / "data"),
    )

    # `pytest.warns(None)` is deprecated since pytest 7 and rejected by
    # pytest 8; record every warning explicitly instead.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        await importer.get_nlu_data()
        await importer.get_stories()

    assert not len(record)
def test_example_bot_training_data_raises_only_auto_fill_warning(
    config_file: Text, domain_file: Text, data_folder: Text
):
    """Loading example bot data emits exactly two identical auto-fill warnings."""
    importer = TrainingDataImporter.load_from_config(
        config_file, domain_file, [data_folder]
    )

    with pytest.warns(UserWarning) as record:
        importer.get_nlu_data()
        importer.get_stories()

    # two for slot auto-fill removal
    assert len(record) == 2
    first_message = record[0].message.args[0]
    assert (
        "Slot auto-fill has been removed in 3.0 and replaced with "
        "a new explicit mechanism to set slots." in first_message
    )
    assert first_message == record[1].message.args[0]
def test_example_bot_training_on_initial_project(tmp_path: Path):
    """The scaffolded initial project emits exactly two auto-fill warnings."""
    # we need to test this one separately, as we can't test it in place
    # configuration suggestions would otherwise change the initial file
    scaffold.create_initial_project(str(tmp_path))

    importer = TrainingDataImporter.load_from_config(
        str(tmp_path / "config.yml"),
        str(tmp_path / "domain.yml"),
        str(tmp_path / "data"),
    )

    with pytest.warns(UserWarning) as record:
        importer.get_nlu_data()
        importer.get_stories()

    # two for slot auto-fill removal
    assert len(record) == 2
    first_message = record[0].message.args[0]
    assert (
        "Slot auto-fill has been removed in 3.0 and replaced with "
        "a new explicit mechanism to set slots." in first_message
    )
    assert first_message == record[1].message.args[0]
def test_should_not_retrain_core(
    domain_path: Text, tmp_path: Path, stack_config_path: Text
):
    """A freshly trained core model is not flagged for core retraining."""
    # Don't use `stories_path` as checkpoints currently break fingerprinting
    story_path = tmp_path / "simple_story.yml"
    story_path.write_text(
        """
stories:
- story: test_story
  steps:
  - intent: greet
  - action: utter_greet
"""
    )

    model_path = train_core(
        domain_path, stack_config_path, str(story_path), str(tmp_path)
    )

    fingerprint = model.model_fingerprint(
        TrainingDataImporter.load_from_config(
            stack_config_path,
            domain_path,
            training_data_paths=[str(story_path)],
        )
    )

    retrain_result = model.should_retrain(fingerprint, model_path, tmp_path)
    assert not retrain_result.should_retrain_core()
def test_no_warnings_with_default_project(tmp_path: Path):
    """Validating the initial project with the default recipe emits no warnings."""
    import warnings

    rasa.utils.common.copy_directory(Path("rasa/cli/initial_project"), tmp_path)

    importer = TrainingDataImporter.load_from_config(
        config_path=str(tmp_path / "config.yml"),
        domain_path=str(tmp_path / "domain.yml"),
        training_data_paths=[str(tmp_path / "data")],
    )

    config, _missing_keys, _configured_keys = DefaultV1Recipe.auto_configure(
        importer.get_config_file_for_auto_config(),
        importer.get_config(),
        TrainingType.END_TO_END,
    )
    graph_config = DefaultV1Recipe().graph_config_for_recipe(
        config, cli_parameters={}, training_type=TrainingType.END_TO_END
    )
    validator = DefaultV1RecipeValidator(graph_config.train_schema)

    # `pytest.warns(None)` is deprecated since pytest 7 and rejected by
    # pytest 8; record every warning explicitly instead.
    with warnings.catch_warnings(record=True) as records:
        warnings.simplefilter("always")
        validator.validate(importer)

    assert len(records) == 0
async def test_formbot_example():
    """Drive a trained formbot agent through happy and unhappy form paths.

    An agent is trained on the `examples/formbot` project. Every action-server
    call to the configured webhook URL is intercepted with `aioresponses`, so
    the test scripts the action server's answers and asserts the bot's first
    text reply for each user message.
    """
    # NOTE(review): presumably added so modules from the example directory can
    # be imported; the entry is never removed from `sys.path` afterwards.
    sys.path.append("examples/formbot/")

    project = Path("examples/formbot/")
    config = str(project / "config.yml")
    domain = str(project / "domain.yml")
    training_dir = project / "data"
    training_files = [
        str(training_dir / "rules.yml"),
        str(training_dir / "stories.yml"),
    ]
    importer = TrainingDataImporter.load_from_config(config, domain, training_files)

    endpoint = EndpointConfig("https://example.com/webhooks/actions")
    endpoints = AvailableEndpoints(action=endpoint)

    agent = await train(
        domain,
        importer,
        str(project / "models" / "dialogue"),
        endpoints=endpoints,
        policy_config="examples/formbot/config.yml",
    )

    async def mock_form_happy_path(input_text, output_text, slot=None):
        """Send `input_text` while the mocked action server answers successfully.

        With a `slot`, the mocked reply carries a form event for
        `restaurant_form`, sets `requested_slot` to `slot` and asks for
        `utter_ask_<slot>`. Without one, the form event's name is None and
        `utter_submit` is requested. Asserts the first reply is `output_text`.
        """
        if slot:
            form = "restaurant_form"
            template = f"utter_ask_{slot}"
        else:
            form = None
            template = "utter_submit"

        response = {
            "events": [
                {"event": "form", "name": form, "timestamp": None},
                {
                    "event": "slot",
                    "timestamp": None,
                    "name": "requested_slot",
                    "value": slot,
                },
            ],
            "responses": [{"template": template}],
        }

        with aioresponses() as mocked:
            mocked.post(
                "https://example.com/webhooks/actions", payload=response, repeat=True
            )

            responses = await agent.handle_text(input_text)
            assert responses[0]["text"] == output_text

    async def mock_form_unhappy_path(input_text, output_text, slot):
        """Send `input_text` while the mocked action server fails for the form.

        Every webhook call raises `ClientResponseError` (first argument 400,
        presumably the HTTP status) whose payload says `slot` could not be
        extracted. Asserts the first reply is `output_text`.
        """
        response_error = {
            "error": f"Failed to extract slot {slot} with action restaurant_form",
            "action_name": "restaurant_form",
        }

        with aioresponses() as mocked:
            # noinspection PyTypeChecker
            mocked.post(
                "https://example.com/webhooks/actions",
                repeat=True,
                exception=ClientResponseError(400, "", json.dumps(response_error)),
            )

            responses = await agent.handle_text(input_text)
            assert responses[0]["text"] == output_text

    # First conversation: fill the form, with a chitchat interruption.
    await mock_form_happy_path("/request_restaurant", "What cuisine?", slot="cuisine")
    await mock_form_unhappy_path("/chitchat", "chitchat", slot="cuisine")
    await mock_form_happy_path(
        '/inform{"cuisine": "mexican"}', "How many people?", slot="num_people"
    )
    await mock_form_happy_path(
        '/inform{"number": "2"}', "Do you want to sit outside?", slot="outdoor_seating"
    )
    await mock_form_happy_path(
        "/affirm", "Please provide additional preferences", slot="preferences"
    )

    # Reset the conversation before the second run.
    responses = await agent.handle_text("/restart")
    assert responses[0]["text"] == "restarted"

    responses = await agent.handle_text("/greet")
    assert (
        responses[0]["text"]
        == "Hello! I am restaurant search assistant! How can I help?"
    )

    # Second conversation: stop midway, confirm to continue, then finish.
    await mock_form_happy_path("/request_restaurant", "What cuisine?", slot="cuisine")
    await mock_form_happy_path(
        '/inform{"cuisine": "mexican"}', "How many people?", slot="num_people"
    )
    await mock_form_happy_path(
        '/inform{"number": "2"}', "Do you want to sit outside?", slot="outdoor_seating"
    )
    await mock_form_unhappy_path(
        "/stop", "Do you want to continue?", slot="outdoor_seating"
    )
    await mock_form_happy_path(
        "/affirm", "Do you want to sit outside?", slot="outdoor_seating"
    )
    await mock_form_happy_path(
        "/affirm", "Please provide additional preferences", slot="preferences"
    )
    await mock_form_happy_path(
        "/deny", "Please give your feedback on your experience so far", slot="feedback"
    )
    await mock_form_happy_path('/inform{"feedback": "great"}', "All done!")

    responses = await agent.handle_text("/thankyou")
    assert responses[0]["text"] == "You are welcome :)"
async def train_async(
    domain: Union[Domain, Text],
    config: Text,
    training_files: Optional[Union[Text, List[Text]]],
    output: Text = DEFAULT_MODELS_PATH,
    dry_run: bool = False,
    force_training: bool = False,
    fixed_model_name: Optional[Text] = None,
    persist_nlu_training_data: bool = False,
    core_additional_arguments: Optional[Dict] = None,
    nlu_additional_arguments: Optional[Dict] = None,
    model_to_finetune: Optional[Text] = None,
    finetuning_epoch_fraction: float = 1.0,
) -> TrainingResult:
    """Trains a Rasa model (Core and NLU).

    Args:
        domain: Path to the domain file (the signature also admits a
            `Domain`); passed unchanged to the training data importer.
        config: Path to the config for Core and NLU.
        training_files: Paths to the training data for Core and NLU.
        output: Output path.
        dry_run: If `True` then no training will be done, and the information
            about whether the training needs to be done will be printed.
        force_training: If `True` retrain model even if data has not changed.
        fixed_model_name: Name of model to be stored.
        persist_nlu_training_data: `True` if the NLU training data should be
            persisted with the model.
        core_additional_arguments: Additional training parameters for core
            training.
        nlu_additional_arguments: Additional training parameters forwarded to
            training method of each NLU component.
        model_to_finetune: Optional path to a model which should be finetuned
            or a directory in case the latest trained model should be used.
        finetuning_epoch_fraction: The fraction of the currently specified
            training epochs in the model configuration which should be used
            for finetuning.

    Returns:
        An instance of `TrainingResult`. If the loaded domain is empty, the
        result wraps whatever `handle_domain_if_not_exists` returns.
    """
    file_importer = TrainingDataImporter.load_from_config(
        config, domain, training_files
    )

    with TempDirectoryPath(tempfile.mkdtemp()) as train_path:
        # Use a separate name instead of shadowing the `domain` argument.
        loaded_domain = await file_importer.get_domain()

        if loaded_domain.is_empty():
            nlu_model = await handle_domain_if_not_exists(
                file_importer, output, fixed_model_name
            )
            return TrainingResult(model=nlu_model)

        return await _train_async_internal(
            file_importer,
            train_path,
            output,
            dry_run,
            force_training,
            fixed_model_name,
            persist_nlu_training_data,
            core_additional_arguments=core_additional_arguments,
            nlu_additional_arguments=nlu_additional_arguments,
            model_to_finetune=model_to_finetune,
            finetuning_epoch_fraction=finetuning_epoch_fraction,
        )
def test_nlu_training_data_provider(
    default_model_storage: ModelStorage,
    default_execution_context: ExecutionContext,
    config_path: Text,
    nlu_data_path: Text,
):
    """`NLUTrainingDataProvider` provides importer data and honours `persist`.

    A provider created with `persist=True` must write the training data to
    `DEFAULT_TRAINING_DATA_OUTPUT_PATH` inside the resource. After that file is
    deleted, a provider created with `persist=False` must not write it again.
    """
    # create a resource and an importer
    resource = Resource("xy")
    importer = TrainingDataImporter.load_from_config(
        config_path=config_path, training_data_paths=[nlu_data_path]
    )

    # check the default configuration is as expected
    config_1 = NLUTrainingDataProvider.get_default_config()
    assert config_1["language"] is None
    assert config_1["persist"] is False

    # create a provider with persist == True
    provider_1 = NLUTrainingDataProvider.create(
        {"language": "en", "persist": True},
        default_model_storage,
        resource,
        default_execution_context,
    )
    assert isinstance(provider_1, NLUTrainingDataProvider)

    # check the data provided is as expected
    data_0 = provider_1.provide(importer)
    data_1 = importer.get_nlu_data(language="en")
    assert data_0.fingerprint() == data_1.fingerprint()

    # check the data was persisted
    with default_model_storage.read_from(resource) as resource_directory:
        data_file = os.path.join(
            str(resource_directory), DEFAULT_TRAINING_DATA_OUTPUT_PATH
        )
        data = load_data(resource_name=data_file, language="en")
        assert os.path.isfile(data_file)
        assert isinstance(data, TrainingData)

    # delete the persisted data so the non-persisting provider below can be
    # checked for not re-creating it
    os.remove(data_file)
    assert not os.path.isfile(data_file)

    # create a provider with persist == False
    provider_2 = NLUTrainingDataProvider.create(
        {"language": "en", "persist": False},
        default_model_storage,
        resource,
        default_execution_context,
    )
    provider_2.provide(importer)

    # check the data was not persisted
    with default_model_storage.read_from(resource) as resource_directory:
        data_file = os.path.join(
            str(resource_directory), DEFAULT_TRAINING_DATA_OUTPUT_PATH
        )
        assert not os.path.isfile(data_file)
def _create_importer_from_config(
    config: Dict[Text, Any], path: Path, config_file_name: Text
) -> TrainingDataImporter:
    """Write `config` to `path / config_file_name` and load an importer from it.

    `True` is passed as the third positional argument of `write_yaml`.
    """
    target = path / config_file_name
    rasa.shared.utils.io.write_yaml(config, target, True)
    return TrainingDataImporter.load_from_config(str(target))
async def test_fingerprinting_changing_config_epochs(project: Text, tmp_path):
    """Epoch-only changes must not alter the epoch-independent config fingerprint.

    Changing only the `epochs` of DIETClassifier, ResponseSelector and
    TEDPolicy must keep `FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY` equal while the
    core and NLU config fingerprints change. Changing the pipeline itself must
    change the epoch-independent fingerprint too.

    NOTE(review): `project` is unused in the body; presumably requested for its
    fixture setup.
    """

    def _config_with_epochs(epochs: int) -> Dict[Text, Any]:
        # Identical pipeline/policies each call; only the `epochs` values (and
        # nothing else, including key order) differ.
        return {
            "language": "en",
            "pipeline": [
                {"name": "WhitespaceTokenizer"},
                {"name": "RegexFeaturizer"},
                {"name": "LexicalSyntacticFeaturizer"},
                {"name": "CountVectorsFeaturizer"},
                {
                    "name": "CountVectorsFeaturizer",
                    "analyzer": "char_wb",
                    "min_ngram": 1,
                    "max_ngram": 4,
                },
                {"name": "DIETClassifier", "epochs": epochs},
                {"name": "EntitySynonymMapper"},
                {"name": "ResponseSelector", "epochs": epochs},
                {
                    "name": "FallbackClassifier",
                    "threshold": 0.3,
                    "ambiguity_threshold": 0.1,
                },
            ],
            "policies": [
                {"name": "MemoizationPolicy"},
                {"name": "TEDPolicy", "max_history": 5, "epochs": epochs},
                {"name": "RulePolicy"},
            ],
        }

    async def _fingerprint_of(config: Dict[Text, Any], file_name: Text):
        # Each config goes to its own file and a freshly loaded importer.
        config_path = tmp_path / file_name
        rasa.shared.utils.io.write_yaml(config, config_path, True)
        importer = TrainingDataImporter.load_from_config(str(config_path))
        return await model_fingerprint(importer)

    old_fingerprint = await _fingerprint_of(_config_with_epochs(100), "config1.yml")

    # Same config, fewer epochs.
    new_fingerprint = await _fingerprint_of(_config_with_epochs(50), "config2.yml")

    assert (
        old_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY]
        == new_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY]
    )
    assert (
        old_fingerprint[FINGERPRINT_CONFIG_CORE_KEY]
        != new_fingerprint[FINGERPRINT_CONFIG_CORE_KEY]
    )
    assert (
        old_fingerprint[FINGERPRINT_CONFIG_NLU_KEY]
        != new_fingerprint[FINGERPRINT_CONFIG_NLU_KEY]
    )

    # A structurally different pipeline must change the epoch-free fingerprint.
    config3 = {
        "language": "en",
        "pipeline": [
            {"name": "WhitespaceTokenizer"},
        ],
        "policies": [
            {"name": "MemoizationPolicy"},
            {"name": "TEDPolicy", "max_history": 5, "epochs": 50},
            {"name": "RulePolicy"},
        ],
    }
    new_fingerprint = await _fingerprint_of(config3, "config3.yml")

    assert (
        old_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY]
        != new_fingerprint[FINGERPRINT_CONFIG_WITHOUT_EPOCHS_KEY]
    )
def train(
    domain: Text,
    config: Text,
    training_files: Optional[Union[Text, List[Text]]],
    output: Text = rasa.shared.constants.DEFAULT_MODELS_PATH,
    dry_run: bool = False,
    force_training: bool = False,
    fixed_model_name: Optional[Text] = None,
    persist_nlu_training_data: bool = False,
    core_additional_arguments: Optional[Dict] = None,
    nlu_additional_arguments: Optional[Dict] = None,
    model_to_finetune: Optional[Text] = None,
    finetuning_epoch_fraction: float = 1.0,
) -> TrainingResult:
    """Trains a Rasa model (Core and NLU).

    The training type is derived from the loaded data:

    - end-to-end if the NLU data has e2e examples (marked experimental),
    - NLU only if the domain is empty or there are no stories,
    - Core only if there is no pure NLU data and no e2e examples,
    - both otherwise.

    Args:
        domain: Path to the domain file.
        config: Path to the config file.
        training_files: List of paths to training data files.
        output: Output directory for the trained model.
        dry_run: If `True` then no training will be done, and the information
            about whether the training needs to be done will be printed.
        force_training: If `True` retrain model even if data has not changed.
        fixed_model_name: Name of model to be stored.
        persist_nlu_training_data: `True` if the NLU training data should be
            persisted with the model.
        core_additional_arguments: Additional training parameters for core
            training.
        nlu_additional_arguments: Additional training parameters forwarded to
            training method of each NLU component.
        model_to_finetune: Optional path to a model which should be finetuned
            or a directory in case the latest trained model should be used.
        finetuning_epoch_fraction: The fraction of the currently specified
            training epochs in the model configuration which should be used
            for finetuning.

    Returns:
        An instance of `TrainingResult`; `TrainingResult(code=1)` when there
        are neither stories nor pure NLU data.
    """
    file_importer = TrainingDataImporter.load_from_config(
        config, domain, training_files
    )

    stories = file_importer.get_stories()
    nlu_data = file_importer.get_nlu_data()

    training_type = TrainingType.BOTH

    if nlu_data.has_e2e_examples():
        rasa.shared.utils.common.mark_as_experimental_feature("end-to-end training")
        training_type = TrainingType.END_TO_END

    # Nothing to train on at all: report and bail out with a non-zero code.
    if stories.is_empty() and nlu_data.contains_no_pure_nlu_data():
        rasa.shared.utils.cli.print_error(
            "No training data given. Please provide stories and NLU data in "
            "order to train a Rasa model using the '--data' argument."
        )
        return TrainingResult(code=1)

    # Rebinds the `domain` argument (a path) to the loaded domain object.
    domain = file_importer.get_domain()
    if domain.is_empty():
        rasa.shared.utils.cli.print_warning(
            "Core training was skipped because no valid domain file was found. "
            "Only an NLU-model was created. Please specify a valid domain using "
            "the '--domain' argument or check if the provided domain file exists."
        )
        training_type = TrainingType.NLU

    elif stories.is_empty():
        rasa.shared.utils.cli.print_warning(
            "No stories present. Just a Rasa NLU model will be trained."
        )
        training_type = TrainingType.NLU

    # We will train NLU if there are any NLU examples, including from e2e stories.
    elif nlu_data.contains_no_pure_nlu_data() and not nlu_data.has_e2e_examples():
        rasa.shared.utils.cli.print_warning(
            "No NLU data present. Just a Rasa Core model will be trained."
        )
        training_type = TrainingType.CORE

    with telemetry.track_model_training(file_importer, model_type="rasa"):
        # NOTE(review): both additional-argument dicts are unpacked into the
        # same call, so a key present in both (or clashing with an explicit
        # keyword below) raises TypeError - confirm this is intended.
        return _train_graph(
            file_importer,
            training_type=training_type,
            output_path=output,
            fixed_model_name=fixed_model_name,
            model_to_finetune=model_to_finetune,
            force_full_training=force_training,
            persist_nlu_training_data=persist_nlu_training_data,
            finetuning_epoch_fraction=finetuning_epoch_fraction,
            dry_run=dry_run,
            **(core_additional_arguments or {}),
            **(nlu_additional_arguments or {}),
        )