Пример #1
0
def train_core(args: argparse.Namespace,
               train_path: Optional[Text] = None) -> Optional[Text]:
    """Train a Rasa Core model, or run a comparison training.

    A single config (or a one-element list of configs) triggers a normal
    Core training. A list with any other number of configs runs a comparison
    training over them instead.

    Args:
        args: Parsed command-line arguments. ``args.domain`` is overwritten
            with the validated domain path, and a one-element ``args.config``
            list is unwrapped in place.
        train_path: Optional directory forwarded to ``rasa.train.train_core``.
            When set, it also takes precedence over ``args.out`` as the
            output location.

    Returns:
        The result of ``rasa.train.train_core`` for a single config, or
        ``None`` when a comparison training was run.
    """
    # Deliberately imported under the same name as this CLI wrapper; the
    # local binding only shadows it inside this function body.
    from rasa.train import train_core

    output = train_path or args.out

    args.domain = get_validated_path(args.domain, "domain", DEFAULT_DOMAIN_PATH)
    stories = get_validated_path(args.stories, "stories", DEFAULT_DATA_PATH)

    # Policies might be a list for the compare training. Do normal training
    # if only one list item was passed.
    if not isinstance(args.config, list) or len(args.config) == 1:
        if isinstance(args.config, list):
            args.config = args.config[0]

        config = get_validated_path(args.config, "config", DEFAULT_CONFIG_PATH)

        return train_core(args.domain, config, stories, output, train_path)
    else:
        import asyncio
        from rasa.core.train import do_compare_training

        # Only the comparison path drives a coroutine directly, so the event
        # loop is fetched here rather than on every call.
        loop = asyncio.get_event_loop()
        loop.run_until_complete(do_compare_training(args, stories, None))
        return None
Пример #2
0
def run_core_training(args: argparse.Namespace,
                      train_path: Optional[Text] = None) -> Optional[Text]:
    """Train only the Rasa Core (dialogue) model.

    If ``args.config`` is a list holding anything other than exactly one
    entry, a comparison training over those configs is started instead of a
    single training.

    Args:
        args: Parsed command-line arguments that configure the training.
            ``args.domain`` is replaced by its validated value, and a
            one-element ``args.config`` list is unwrapped in place.
        train_path: Optional directory forwarded to ``train_core`` as
            ``train_path``. When given, it also replaces ``args.out`` as the
            output location.

    Returns:
        Whatever ``train_core`` returns (the path to the trained model, or
        ``None`` if training was not successful), or ``None`` when a
        comparison training was run.
    """
    from rasa.model_training import train_core

    output = train_path or args.out

    args.domain = rasa.cli.utils.get_validated_path(
        args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True
    )
    story_file = rasa.cli.utils.get_validated_path(
        args.stories, "stories", DEFAULT_DATA_PATH, none_is_valid=True
    )
    additional_arguments = extract_core_additional_arguments(args)

    configs = args.config
    run_comparison = isinstance(configs, list) and len(configs) != 1
    if run_comparison:
        # Several policy configs: train one model per config and compare.
        do_compare_training(args, story_file, additional_arguments)
        return None

    # Unwrap a single-element list so normal training gets a plain path.
    if isinstance(configs, list):
        args.config = configs[0]

    config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_CORE)

    return train_core(
        domain=args.domain,
        config=config,
        stories=story_file,
        output=output,
        train_path=train_path,
        fixed_model_name=args.fixed_model_name,
        additional_arguments=additional_arguments,
        model_to_finetune=_model_for_finetuning(args),
        finetuning_epoch_fraction=args.epoch_fraction,
    )
Пример #3
0
def train_core(args: argparse.Namespace,
               train_path: Optional[Text] = None) -> Optional[Text]:
    """Train a Core model, or compare several policy configurations.

    Args:
        args: Parsed command-line arguments. ``args.domain`` is replaced by
            its validated value, and a one-element ``args.config`` list is
            unwrapped in place.
        train_path: Directory where models should be stored. When given, it
            also takes precedence over ``args.out`` as the output location.

    Returns:
        Path to a trained model, or ``None`` if training was not successful
        or a comparison training was run instead.
    """
    from rasa.train import train_core

    output = train_path or args.out

    args.domain = rasa.cli.utils.get_validated_path(
        args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True
    )
    story_file = rasa.cli.utils.get_validated_path(
        args.stories, "stories", DEFAULT_DATA_PATH, none_is_valid=True
    )
    additional_arguments = extract_core_additional_arguments(args)

    configs = args.config
    if isinstance(configs, list) and len(configs) != 1:
        # Multiple configs given: run the comparison training instead.
        from rasa.core.train import do_compare_training

        rasa.utils.common.run_in_loop(
            do_compare_training(args, story_file, additional_arguments)
        )
        return None

    # A single-element list is treated like a plain config path.
    if isinstance(configs, list):
        args.config = configs[0]

    config = _get_valid_config(args.config, CONFIG_MANDATORY_KEYS_CORE)

    return train_core(
        domain=args.domain,
        config=config,
        stories=story_file,
        output=output,
        train_path=train_path,
        fixed_model_name=args.fixed_model_name,
        additional_arguments=additional_arguments,
    )
Пример #4
0
def train_core(
    args: argparse.Namespace, train_path: Optional[Text] = None
) -> Optional[Text]:
    """Train a Rasa Core model, or run a comparison training.

    A single config (or a one-element list of configs) triggers a normal
    Core training. A list with any other number of configs runs a comparison
    training over them instead.

    Args:
        args: Parsed command-line arguments. ``args.domain`` is overwritten
            with the validated domain path, and a one-element ``args.config``
            list is unwrapped in place.
        train_path: Optional directory forwarded to ``rasa.train.train_core``.
            When set, it also takes precedence over ``args.out`` as the
            output location.

    Returns:
        The result of ``rasa.train.train_core`` for a single config, or
        ``None`` when a comparison training was run.
    """
    # Deliberately imported under the same name as this CLI wrapper; the
    # local binding only shadows it inside this function body.
    from rasa.train import train_core

    output = train_path or args.out

    args.domain = get_validated_path(
        args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True
    )
    stories = get_validated_path(
        args.stories, "stories", DEFAULT_DATA_PATH, none_is_valid=True
    )

    # Policies might be a list for the compare training. Do normal training
    # if only one list item was passed.
    if not isinstance(args.config, list) or len(args.config) == 1:
        if isinstance(args.config, list):
            args.config = args.config[0]

        # The config path is not validated here; fall back to the default
        # path when none was given.
        config = args.config or DEFAULT_CONFIG_PATH

        return train_core(
            domain=args.domain,
            config=config,
            stories=stories,
            output=output,
            train_path=train_path,
            fixed_model_name=args.fixed_model_name,
            uncompress=args.store_uncompressed,
            kwargs=extract_additional_arguments(args),
        )
    else:
        import asyncio
        from rasa.core.train import do_compare_training

        # Only the comparison path drives a coroutine directly, so the event
        # loop is fetched here rather than on every call.
        loop = asyncio.get_event_loop()
        # NOTE(review): unlike the single-config path, no additional
        # arguments are forwarded to the comparison training (``None``);
        # confirm this is intended.
        loop.run_until_complete(do_compare_training(args, stories, None))
        return None