Beispiel #1
0
async def train(
    domain_file: Union[Domain, Text],
    training_resource: Union[Text, "TrainingDataImporter"],
    output_path: Text,
    interpreter: Optional["NaturalLanguageInterpreter"] = None,
    endpoints: Optional["AvailableEndpoints"] = None,
    policy_config: Optional[Union[Text, Dict]] = None,
    exclusion_percentage: Optional[int] = None,
    additional_arguments: Optional[Dict] = None,
    model_to_finetune: Optional["Agent"] = None,
) -> "Agent":
    """Train a dialogue agent, persist it to ``output_path`` and return it.

    An ``Agent`` is built from ``domain_file`` and the policies loaded from
    ``policy_config``. Training data is loaded from ``training_resource``,
    the agent is trained on it and then persisted.

    Args:
        domain_file: Domain object or domain path passed to ``Agent``.
        training_resource: Path or ``TrainingDataImporter`` handed to
            ``Agent.load_data``.
        output_path: Location passed to ``Agent.persist``.
        interpreter: Optional NLU interpreter passed to ``Agent``.
        endpoints: Endpoint configuration whose ``nlg`` and ``action``
            entries are passed to ``Agent``. A default-constructed
            ``AvailableEndpoints()`` is used when falsy.
        policy_config: Policy configuration passed to ``config.load``.
        exclusion_percentage: Forwarded to ``Agent.load_data``.
        additional_arguments: Extra options. The data-loading keys
            (``use_story_concatenation``, ``unique_last_num_states``,
            ``augmentation_factor``, ``remove_duplicates``, ``debug_plots``)
            go to ``Agent.load_data``; the remaining ones go to
            ``Agent.train``.
        model_to_finetune: If given, its ``policy_ensemble`` replaces the
            policies loaded from ``policy_config`` before training.

    Returns:
        The trained ``Agent``.
    """
    from rasa.core import config, utils
    from rasa.core.utils import AvailableEndpoints
    from rasa.core.agent import Agent

    if not endpoints:
        endpoints = AvailableEndpoints()

    if not additional_arguments:
        additional_arguments = {}

    policies = config.load(policy_config)

    agent = Agent(
        domain_file,
        generator=endpoints.nlg,
        action_endpoint=endpoints.action,
        interpreter=interpreter,
        policies=policies,
    )

    # Route the data-loading options to ``load_data``; whatever is left of
    # ``additional_arguments`` is forwarded to ``Agent.train`` below.
    data_load_args, additional_arguments = utils.extract_args(
        additional_arguments,
        {
            "use_story_concatenation",
            "unique_last_num_states",
            "augmentation_factor",
            "remove_duplicates",
            "debug_plots",
        },
    )
    training_data = await agent.load_data(
        training_resource,
        exclusion_percentage=exclusion_percentage,
        **data_load_args,
    )

    if model_to_finetune:
        # Train on top of the existing model's policies instead of the
        # freshly configured ones.
        agent.policy_ensemble = model_to_finetune.policy_ensemble

    agent.train(training_data, **additional_arguments)
    agent.persist(output_path)

    return agent
Beispiel #2
0
async def train(
    domain_file: Union[Domain, Text],
    stories_file: Text,
    output_path: Text,
    interpreter: Optional["NaturalLanguageInterpreter"] = None,
    endpoints: Optional["AvailableEndpoints"] = None,
    dump_stories: bool = False,
    policy_config: Optional[Text] = None,
    exclusion_percentage: Optional[int] = None,
    kwargs: Optional[Dict] = None,
) -> "Agent":
    """Train a dialogue agent on ``stories_file`` and persist it.

    An ``Agent`` is built from ``domain_file`` and the policies loaded from
    ``policy_config``. Stories are loaded from ``stories_file``, the agent is
    trained on them and then persisted to ``output_path``.

    Args:
        domain_file: Domain object or domain path passed to ``Agent``.
        stories_file: Path handed to ``Agent.load_data``.
        output_path: Location passed to ``Agent.persist``.
        interpreter: Optional NLU interpreter passed to ``Agent``.
        endpoints: Endpoint configuration whose ``nlg`` and ``action``
            entries are passed to ``Agent``. A default-constructed
            ``AvailableEndpoints()`` is used when falsy.
        dump_stories: Passed as the second positional argument of
            ``Agent.persist``.
        policy_config: Policy configuration passed to ``config.load``.
        exclusion_percentage: Forwarded to ``Agent.load_data``.
        kwargs: Extra options (a plain dict, not ``**kwargs``). The
            data-loading keys (``use_story_concatenation``,
            ``unique_last_num_states``, ``augmentation_factor``,
            ``remove_duplicates``, ``debug_plots``) go to
            ``Agent.load_data``; the remaining ones go to ``Agent.train``.

    Returns:
        The trained ``Agent``.
    """
    from rasa.core.agent import Agent
    from rasa.core import config, utils
    from rasa.core.utils import AvailableEndpoints

    if not endpoints:
        endpoints = AvailableEndpoints()

    if not kwargs:
        kwargs = {}

    policies = config.load(policy_config)

    agent = Agent(
        domain_file,
        generator=endpoints.nlg,
        action_endpoint=endpoints.action,
        interpreter=interpreter,
        policies=policies,
    )

    # Route the data-loading options to ``load_data``; whatever is left of
    # ``kwargs`` is forwarded to ``Agent.train`` below.
    data_load_args, kwargs = utils.extract_args(
        kwargs,
        {
            "use_story_concatenation",
            "unique_last_num_states",
            "augmentation_factor",
            "remove_duplicates",
            "debug_plots",
        },
    )

    training_data = await agent.load_data(
        stories_file,
        exclusion_percentage=exclusion_percentage,
        **data_load_args,
    )
    agent.train(training_data, **kwargs)
    agent.persist(output_path, dump_stories)

    return agent
Beispiel #3
0
async def train(
    domain_file: Union[Domain, Text],
    training_resource: Union[Text, "TrainingDataImporter"],
    output_path: Text,
    interpreters: Optional[Dict[Text, "NaturalLanguageInterpreter"]] = None,
    endpoints: Optional["AvailableEndpoints"] = None,
    dump_stories: bool = False,
    policy_config: Optional[Union[Text, Dict]] = None,
    exclusion_percentage: Optional[int] = None,
    additional_arguments: Optional[Dict] = None,
) -> "Agent":
    """Train a dialogue agent, persist it to ``output_path`` and return it.

    An ``Agent`` is built from ``domain_file`` and the policies loaded from
    ``policy_config``. Training data is loaded from ``training_resource``,
    the agent is trained on it and then persisted.

    Args:
        domain_file: Domain object or domain path passed to ``Agent``.
        training_resource: Path or ``TrainingDataImporter`` handed to
            ``Agent.load_data``.
        output_path: Location passed to ``Agent.persist``.
        interpreters: Mapping of NLU interpreters passed to ``Agent``; an
            empty dict is used when ``None``/falsy.
        endpoints: Endpoint configuration whose ``nlg`` and ``action``
            entries are passed to ``Agent``. A default-constructed
            ``AvailableEndpoints()`` is used when falsy.
        dump_stories: Passed as the second positional argument of
            ``Agent.persist``.
        policy_config: Policy configuration passed to ``config.load``.
        exclusion_percentage: Forwarded to ``Agent.load_data``.
        additional_arguments: Extra options. The data-loading keys
            (``use_story_concatenation``, ``unique_last_num_states``,
            ``augmentation_factor``, ``remove_duplicates``, ``debug_plots``)
            go to ``Agent.load_data``; the remaining ones go to
            ``Agent.train``.

    Returns:
        The trained ``Agent``.
    """
    from rasa.core.agent import Agent
    from rasa.core import config, utils
    from rasa.core.utils import AvailableEndpoints

    if not endpoints:
        endpoints = AvailableEndpoints()

    if not additional_arguments:
        additional_arguments = {}

    policies = config.load(policy_config)

    agent = Agent(
        domain_file,
        generator=endpoints.nlg,
        action_endpoint=endpoints.action,
        interpreters=interpreters or {},  # fix to avoid model not ready error
        policies=policies,
    )

    # Route the data-loading options to ``load_data``; whatever is left of
    # ``additional_arguments`` is forwarded to ``Agent.train`` below.
    data_load_args, additional_arguments = utils.extract_args(
        additional_arguments,
        {
            "use_story_concatenation",
            "unique_last_num_states",
            "augmentation_factor",
            "remove_duplicates",
            "debug_plots",
        },
    )
    training_data = await agent.load_data(
        training_resource,
        exclusion_percentage=exclusion_percentage,
        **data_load_args,
    )
    agent.train(training_data, **additional_arguments)
    agent.persist(output_path, dump_stories)

    return agent