Esempio n. 1
0
async def train(
    domain_file: Union["Domain", Text],
    training_resource: Union[Text, "TrainingDataImporter"],
    output_path: Text,
    interpreter: Optional["NaturalLanguageInterpreter"] = None,
    endpoints: Optional["AvailableEndpoints"] = None,
    policy_config: Optional[Union[Text, Dict]] = None,
    exclusion_percentage: Optional[int] = None,
    additional_arguments: Optional[Dict] = None,
    model_to_finetune: Optional["Agent"] = None,
) -> "Agent":
    """Build, train and persist a Core ``Agent``.

    Args:
        domain_file: Domain (object or text) passed straight to ``Agent``.
        training_resource: Training data source handed to ``Agent.load_data``.
        output_path: Location handed to ``Agent.persist`` after training.
        interpreter: Forwarded to the ``Agent`` constructor.
        endpoints: Supplies the NLG generator (``endpoints.nlg``) and the
            action endpoint (``endpoints.action``) for the agent. A default
            ``AvailableEndpoints()`` is used when omitted.
        policy_config: Passed to ``rasa.core.config.load`` to build the
            policies given to the ``Agent`` constructor.
        exclusion_percentage: Forwarded to ``Agent.load_data``.
        additional_arguments: Extra keyword arguments. The keys
            ``use_story_concatenation``, ``unique_last_num_states``,
            ``augmentation_factor``, ``remove_duplicates`` and ``debug_plots``
            are extracted and forwarded to ``Agent.load_data``. The remaining
            arguments returned by ``utils.extract_args`` are forwarded to
            ``Agent.train``.
        model_to_finetune: If given, its ``policy_ensemble`` replaces the new
            agent's ensemble before training.

    Returns:
        The trained agent, after it has been persisted to ``output_path``.
    """
    # Imported inside the function, presumably to avoid import cycles or
    # heavy module-level imports — TODO confirm.
    from rasa.core import config, utils
    from rasa.core.utils import AvailableEndpoints
    from rasa.core.agent import Agent

    if endpoints is None:
        endpoints = AvailableEndpoints()

    if additional_arguments is None:
        additional_arguments = {}

    # NOTE(review): when model_to_finetune is given, these policies are still
    # built and passed to Agent, but the ensemble is replaced below. Confirm
    # that this is intended.
    policies = config.load(policy_config)

    agent = Agent(
        domain_file,
        generator=endpoints.nlg,
        action_endpoint=endpoints.action,
        interpreter=interpreter,
        policies=policies,
    )

    # Keyword arguments that configure data loading rather than training.
    data_load_keys = {
        "use_story_concatenation",
        "unique_last_num_states",
        "augmentation_factor",
        "remove_duplicates",
        "debug_plots",
    }
    data_load_args, additional_arguments = utils.extract_args(
        additional_arguments, data_load_keys
    )

    training_data = await agent.load_data(
        training_resource,
        exclusion_percentage=exclusion_percentage,
        **data_load_args,
    )

    if model_to_finetune is not None:
        # Continue training from the existing model's policies instead of
        # the freshly configured ones.
        agent.policy_ensemble = model_to_finetune.policy_ensemble

    agent.train(training_data, **additional_arguments)
    agent.persist(output_path)

    return agent