Beispiel #1
0
    async def train(request: Request):
        """Train a Rasa model from the JSON request body and return the model file.

        Recognised body keys:
            config:  model configuration, written to ``config.yml`` (required;
                     accessed unconditionally).
            nlu:     optional NLU training data, written to ``nlu.md``.
            stories: optional stories, written to ``stories.md``.
            domain:  optional domain, written to ``domain.yml``; when absent
                     ``DEFAULT_DOMAIN_PATH`` is used.
            out:     optional output directory (default ``DEFAULT_MODELS_PATH``).
            force:   optional value passed as ``force_training`` (default False).

        The request data is staged in a fresh temporary directory which is
        removed before the handler returns.

        Raises:
            ErrorResponse: 400 ``InvalidDomainError`` if the domain is invalid;
                500 ``TrainingError`` if training fails or yields no model.
                (``validate_request_body`` / ``validate_request`` presumably
                reject malformed bodies first — confirm in their definitions.)
        """
        import shutil

        from rasa.train import train_async

        validate_request_body(
            request,
            "You must provide training data in the request body in order to "
            "train your model.",
        )

        rjs = request.json
        validate_request(rjs)

        # Config, domain and training data are staged in a temporary directory
        # that is handed to the trainer as `training_files`.
        temp_dir = tempfile.mkdtemp()
        try:
            config_path = os.path.join(temp_dir, "config.yml")
            dump_obj_as_str_to_file(config_path, rjs["config"])

            if "nlu" in rjs:
                nlu_path = os.path.join(temp_dir, "nlu.md")
                dump_obj_as_str_to_file(nlu_path, rjs["nlu"])

            if "stories" in rjs:
                stories_path = os.path.join(temp_dir, "stories.md")
                dump_obj_as_str_to_file(stories_path, rjs["stories"])

            domain_path = DEFAULT_DOMAIN_PATH
            if "domain" in rjs:
                domain_path = os.path.join(temp_dir, "domain.yml")
                dump_obj_as_str_to_file(domain_path, rjs["domain"])

            try:
                model_path = await train_async(
                    domain=domain_path,
                    config=config_path,
                    training_files=temp_dir,
                    output_path=rjs.get("out", DEFAULT_MODELS_PATH),
                    force_training=rjs.get("force", False),
                )

                # Without a model there is nothing to send back; passing None
                # to `response.file` would only fail with an obscure TypeError.
                if not model_path:
                    raise ErrorResponse(
                        500,
                        "TrainingError",
                        "Ran training, but it finished without a trained model.",
                    )

                filename = os.path.basename(model_path)
                return await response.file(
                    model_path, filename=filename, headers={"filename": filename}
                )
            except ErrorResponse:
                # Already a well-formed HTTP error; do not re-wrap it below.
                raise
            except InvalidDomain as e:
                raise ErrorResponse(
                    400,
                    "InvalidDomainError",
                    "Provided domain file is invalid. Error: {}".format(e),
                ) from e
            except Exception as e:
                logger.debug(traceback.format_exc())
                raise ErrorResponse(
                    500,
                    "TrainingError",
                    "An unexpected error occurred during training. Error: {}".format(e),
                ) from e
        finally:
            # The staged files are only needed while training runs.
            shutil.rmtree(temp_dir, ignore_errors=True)
Beispiel #2
0
    async def post_data_convert(request: Request):
        """Convert NLU training data from the request body to Markdown or JSON.

        Expected JSON body keys:
            data:          training data; a dict is written to disk as JSON,
                           anything else via ``dump_obj_as_str_to_file``.
            output_format: target format, either ``"md"`` or ``"json"``.
            language:      language passed to ``convert_training_data``.

        Returns a JSON response ``{"data": ...}`` with the converted data: the
        raw file text for ``md``, the parsed JSON object for ``json``.

        Raises:
            ErrorResponse: 400 ``BadRequest`` when ``data`` or ``language`` is
                missing, or ``output_format`` is missing or unsupported.
        """
        import json
        import shutil

        validate_request_body(
            request,
            "You must provide training data in the request body in order to "
            "train your model.",
        )
        rjs = request.json

        if "data" not in rjs:
            raise ErrorResponse(
                400, "BadRequest", "Must provide training data in 'data' property"
            )
        if rjs.get("output_format") not in ("json", "md"):
            raise ErrorResponse(
                400,
                "BadRequest",
                "'output_format' is required and must be either 'md' or 'json'",
            )
        if "language" not in rjs:
            raise ErrorResponse(400, "BadRequest", "'language' is required")

        temp_dir = tempfile.mkdtemp()
        out_dir = tempfile.mkdtemp()
        try:
            nlu_data_path = os.path.join(temp_dir, "nlu_data")
            output_path = os.path.join(out_dir, "output")

            # botfront: several nlu files
            # A dict payload (JSON training data) is serialised as JSON; other
            # payloads (e.g. Markdown text) are written as a string.
            if isinstance(rjs["data"], dict):
                from rasa.core.utils import dump_obj_as_json_to_file

                dump_obj_as_json_to_file(nlu_data_path, rjs["data"])
            else:
                dump_obj_as_str_to_file(nlu_data_path, rjs["data"])
            # botfront end

            from rasa.nlu.convert import convert_training_data

            convert_training_data(
                nlu_data_path, output_path, rjs["output_format"], rjs["language"]
            )

            with open(output_path, encoding="utf-8") as f:
                data = f.read()
        finally:
            # Input and output files are only needed for the conversion itself.
            shutil.rmtree(temp_dir, ignore_errors=True)
            shutil.rmtree(out_dir, ignore_errors=True)

        if rjs["output_format"] == "json":
            # `data` is already a str; json.loads no longer accepts `encoding`
            # (removed in Python 3.9).
            data = json.loads(data)

        return response.json({"data": data})
Beispiel #3
0
    async def train_stack(request: Request):
        """Train a Rasa Stack model from the JSON request body and return it.

        The body must be a JSON object containing ``config``, ``domain``,
        ``nlu`` and ``stories``. Each is written to a file in a fresh temporary
        directory. Optional keys:
            out:   output directory (defaults to that temporary directory).
            force: value passed as ``force_training`` (default False).

        Raises:
            ErrorResponse: 400 ``TrainingError`` if the body is missing, is not
                an object, lacks a required key, or if training fails.
        """
        import shutil

        from rasa.train import train_async

        rjs = request.json

        # Check up front so an absent or non-object body gets the same 400 as
        # a missing key instead of failing with a TypeError on `rjs[...]`.
        required_keys = ("config", "domain", "nlu", "stories")
        if not isinstance(rjs, dict) or any(k not in rjs for k in required_keys):
            raise ErrorResponse(
                400,
                "TrainingError",
                "The Rasa Stack training request is "
                "missing a key. The required keys are "
                "`config`, `domain`, `nlu` and `stories`.",
            )

        # create a temporary directory to store config, domain and
        # training data
        temp_dir = tempfile.mkdtemp()
        try:
            config_path = os.path.join(temp_dir, "config.yml")
            dump_obj_as_str_to_file(config_path, rjs["config"])

            domain_path = os.path.join(temp_dir, "domain.yml")
            dump_obj_as_str_to_file(domain_path, rjs["domain"])

            nlu_path = os.path.join(temp_dir, "nlu.md")
            dump_obj_as_str_to_file(nlu_path, rjs["nlu"])

            stories_path = os.path.join(temp_dir, "stories.md")
            dump_obj_as_str_to_file(stories_path, rjs["stories"])

            # the model will be saved to the same temporary dir
            # unless `out` was specified in the request
            try:
                model_path = await train_async(
                    domain=domain_path,
                    config=config_path,
                    training_files=[nlu_path, stories_path],
                    output=rjs.get("out", temp_dir),
                    force_training=rjs.get("force", False),
                )

                return await response.file(model_path)
            except Exception as e:
                raise ErrorResponse(
                    400,
                    "TrainingError",
                    "Rasa Stack model could not be trained. Error: {}".format(e),
                ) from e
        finally:
            # NOTE(review): removing the directory after `response.file` returns
            # relies on it having read the model into the response body (Sanic's
            # `file()` does) — re-check if the framework changes.
            shutil.rmtree(temp_dir, ignore_errors=True)
Beispiel #4
0
    async def train(request: Request):
        """Train a Rasa model from the JSON request body and return the model file.

        Botfront variant. The request body keys are:
            config:           mapping of key -> configuration, each written to
                              ``config/<key>.yml``. The resulting path mapping
                              is passed to ``train_async``. The keys are presumably
                              language codes — TODO confirm.
            nlu:              optional mapping of key -> object whose ``data``
                              entry is written to ``nlu/<key>.md``.
            stories:          optional stories, written to ``stories.md``.
            domain:           optional domain; ``DEFAULT_DOMAIN_PATH`` otherwise.
            out:              output directory (default ``DEFAULT_MODELS_PATH``).
            force:            value passed as ``force_training`` (default False).
            fixed_model_name: optional fixed name for the trained model.

        All request data is staged in a fresh temporary directory, which is
        removed before the handler returns.

        Raises:
            ErrorResponse: 400 ``BadRequest`` if a ``config``/``nlu`` key
                contains path components; 400 ``InvalidDomainError`` if the
                domain is invalid; 500 ``TrainingError`` if training fails or
                yields no model.
        """
        import shutil

        from rasa.train import train_async

        def _file_name(key, extension):
            # Keys come from the request body and become file names, so reject
            # any key with path components that could escape the temp dir.
            name = "{}".format(key)
            if os.path.basename(name) != name:
                raise ErrorResponse(
                    400,
                    "BadRequest",
                    "Invalid key '{}': keys must not contain path "
                    "separators.".format(name),
                )
            return "{}.{}".format(name, extension)

        validate_request_body(
            request,
            "You must provide training data in the request body in order to "
            "train your model.",
        )

        rjs = request.json
        validate_request(rjs)

        # create a temporary directory to store config, domain and
        # training data
        temp_dir = tempfile.mkdtemp()
        try:
            # bf mod
            config_dir = os.path.join(temp_dir, "config")
            os.mkdir(config_dir)

            config_paths = {}
            for key, config in rjs["config"].items():
                config_file_path = os.path.join(config_dir, _file_name(key, "yml"))
                dump_obj_as_str_to_file(config_file_path, config)
                config_paths[key] = config_file_path

            if "nlu" in rjs:
                nlu_dir = os.path.join(temp_dir, "nlu")
                os.mkdir(nlu_dir)

                for key, nlu in rjs["nlu"].items():
                    nlu_file_path = os.path.join(nlu_dir, _file_name(key, "md"))
                    dump_obj_as_str_to_file(nlu_file_path, nlu["data"])
            # /bf mod

            if "stories" in rjs:
                stories_path = os.path.join(temp_dir, "stories.md")
                dump_obj_as_str_to_file(stories_path, rjs["stories"])

            domain_path = DEFAULT_DOMAIN_PATH
            if "domain" in rjs:
                domain_path = os.path.join(temp_dir, "domain.yml")
                dump_obj_as_str_to_file(domain_path, rjs["domain"])

            try:
                model_path = await train_async(
                    domain=domain_path,
                    config=config_paths,
                    training_files=temp_dir,
                    output_path=rjs.get("out", DEFAULT_MODELS_PATH),
                    force_training=rjs.get("force", False),
                    # botfront: add the possibility to pass a fixed name in the json payload
                    fixed_model_name=rjs.get("fixed_model_name"),
                    # persist data file for nlu components to use
                    persist_nlu_training_data=True,
                )

                # Without a model there is nothing to send back; passing None
                # to `response.file` would only fail with an obscure TypeError.
                if not model_path:
                    raise ErrorResponse(
                        500,
                        "TrainingError",
                        "Ran training, but it finished without a trained model.",
                    )

                filename = os.path.basename(model_path)
                return await response.file(
                    model_path, filename=filename, headers={"filename": filename}
                )
            except ErrorResponse:
                # Already a well-formed HTTP error; do not re-wrap it below.
                raise
            except InvalidDomain as e:
                raise ErrorResponse(
                    400,
                    "InvalidDomainError",
                    "Provided domain file is invalid. Error: {}".format(e),
                ) from e
            except Exception as e:
                logger.debug(traceback.format_exc())
                raise ErrorResponse(
                    500,
                    "TrainingError",
                    "An unexpected error occurred during training. Error: {}".format(e),
                ) from e
        finally:
            # The staged files are only needed while training runs.
            shutil.rmtree(temp_dir, ignore_errors=True)