示例#1
0
    async def _create_project_store(self, project_dir: Text) -> Dict[Text, Any]:
        """Build the initial mapping of project name -> project object.

        When a model server is configured, only the default project is
        created, via ``load_from_server``. Otherwise a ``Project`` is created
        for every name returned by ``self._collect_projects(project_dir)``;
        if that yields nothing, a default project is created so the store is
        never empty.

        Args:
            project_dir: directory handed to ``_collect_projects`` to discover
                existing projects.

        Returns:
            Dict mapping project names to the objects created for them.
        """
        default_project = RasaNLUModelConfig.DEFAULT_PROJECT_NAME

        project_store = {}

        if self.model_server is not None:
            project_store[default_project] = await load_from_server(
                self.component_builder,
                default_project,
                self.project_dir,
                self.remote_storage,
                self.model_server,
                self.wait_time_between_pulls,
            )
            return project_store

        # Project discovery is only needed when models are not pulled from a
        # model server, so it is not performed on the branch above.
        for project in self._collect_projects(project_dir):
            project_store[project] = Project(
                self.component_builder,
                project,
                self.project_dir,
                self.remote_storage,
            )

        if not project_store:
            # Use the same positional arguments as the projects above so the
            # default project also receives the router's component builder.
            project_store[default_project] = Project(
                self.component_builder,
                default_project,
                self.project_dir,
                self.remote_storage,
            )

        return project_store
示例#2
0
def test_dynamic_load_model_with_model_is_none():
    """Requesting no model falls back to the latest project model.

    ``Project`` is instantiated without running its real ``__init__``; the
    model search is stubbed to do nothing (``_models`` stays empty) and
    ``_latest_project_model`` is stubbed to return a fixed name, which
    ``_dynamic_load_model(None)`` is expected to return.
    """
    expected_name = 'latest_model_name'

    def fake_init(*args, **kwargs):
        return None

    def fake_search_for_models(self):
        pass

    def fake_latest_project_model(self):
        return expected_name

    stubs = mock.patch.multiple(
        Project,
        __init__=fake_init,
        _search_for_models=fake_search_for_models,
        _latest_project_model=fake_latest_project_model,
    )
    with stubs:
        instance = Project()
        instance._models = ()
        instance.pull_models = None

        assert instance._dynamic_load_model(None) == expected_name
示例#3
0
    async def start_train_process(
        self,
        data_file: Text,
        project: Text,
        train_config: RasaNLUModelConfig,
        model_name: Optional[Text] = None,
    ) -> Text:
        """Start a model training in the worker pool and wait for it.

        Args:
            data_file: training data forwarded to ``do_train_in_worker``.
            project: name of the project to train; it is added to the
                project store if not already present.
            train_config: training configuration forwarded to the worker.
            model_name: optional fixed model name forwarded to the worker.

        Returns:
            The model path returned by ``do_train_in_worker``.

        Raises:
            InvalidProjectError: if ``project`` is empty.
            MaxWorkerProcessError: if all worker slots are already in use.
            Exception: any error raised while scheduling or running the
                training is re-raised after the project is marked as failed.
        """
        if not project:
            raise InvalidProjectError("Missing project name to train")

        if self._worker_processes <= self._current_worker_processes:
            raise MaxWorkerProcessError

        if project not in self.project_store:
            self.project_store[project] = Project(
                self.component_builder, project, self.project_dir, self.remote_storage
            )
        self.project_store[project].status = STATUS_TRAINING

        loop = asyncio.get_event_loop()

        logger.debug("New training queued")

        self._current_worker_processes += 1
        self.project_store[project].current_worker_processes += 1

        try:
            # run_in_executor submits to the pool immediately; keeping it
            # inside the try guarantees the counters above are released in
            # ``finally`` even when submission itself fails.
            model_path = await loop.run_in_executor(
                self.pool,
                do_train_in_worker,
                train_config,
                data_file,
                self.project_dir,
                project,
                model_name,
                self.remote_storage,
            )
            model_dir = os.path.basename(os.path.normpath(model_path))
            self.project_store[project].update(model_dir)

            # The counter is only decremented in ``finally``, so a value of 1
            # means this is the sole training still running for the project.
            if (
                self.project_store[project].current_worker_processes == 1
                and self.project_store[project].status == STATUS_TRAINING
            ):
                self.project_store[project].status = STATUS_READY
            return model_path
        except Exception as e:
            logger.warning(e)
            self.project_store[project].status = STATUS_FAILED
            self.project_store[project].error_message = str(e)
            raise
        finally:
            self._current_worker_processes -= 1
            self.project_store[project].current_worker_processes -= 1
示例#4
0
    async def parse(self, data: Dict[Text, Any]) -> Dict[Text, Any]:
        """Parse ``data["text"]`` with the requested project, loading it on demand.

        Args:
            data: request payload. ``"text"`` is required; ``"project"``
                (defaults to ``RasaNLUModelConfig.DEFAULT_PROJECT_NAME``),
                ``"model"`` and ``"time"`` are optional and forwarded to the
                project's ``parse``.

        Returns:
            The project's parse result passed through ``self.format_response``.

        Raises:
            InvalidProjectError: if the project is not in the store and is
                listed neither by ``_list_projects(self.project_dir)`` nor by
                ``_list_projects_in_cloud()``, or if constructing it fails.
            KeyError: if ``data`` has no ``"text"`` entry.
        """
        project = data.get("project", RasaNLUModelConfig.DEFAULT_PROJECT_NAME)
        model = data.get("model")

        if project not in self.project_store:
            # Build a fresh list instead of extending the one returned by
            # ``_list_projects``, so that helper's return value is not mutated.
            projects = [
                *self._list_projects(self.project_dir),
                *self._list_projects_in_cloud(),
            ]

            if project not in projects:
                raise InvalidProjectError(
                    "No project found with name '{}'.".format(project)
                )

            try:
                self.project_store[project] = Project(
                    self.component_builder,
                    project,
                    self.project_dir,
                    self.remote_storage,
                )
            except Exception as e:
                # Chain the original error so its traceback is preserved.
                raise InvalidProjectError(
                    "Unable to load project '{}'. Error: {}".format(project, e)
                ) from e

        time = data.get("time")
        response = self.project_store[project].parse(data["text"], time, model)

        if self.responses:
            self.responses.info(response)

        return self.format_response(response)
示例#5
0
def test_dynamic_load_model_with_exists_model():
    """A requested model already present in ``_models`` is returned.

    ``Project.__init__`` is stubbed out so the instance can be built without
    any setup; its model list is then seeded with the requested name.
    """
    wanted = 'model_name'

    def skip_init(*args, **kwargs):
        return None

    with mock.patch.object(Project, "__init__", skip_init):
        instance = Project()
        instance._models = (wanted,)
        instance.pull_models = None

        loaded = instance._dynamic_load_model(wanted)

        assert loaded == wanted
示例#6
0
    def start_train_process(self,
                            data_file: Text,
                            project: Text,
                            train_config: RasaNLUModelConfig,
                            model_name: Optional[Text] = None) -> Deferred:
        """Start a model training in the worker pool.

        Args:
            data_file: training data forwarded to ``do_train_in_worker``.
            project: name of the project to train; it is added to the
                project store if not already present.
            train_config: training configuration forwarded to the worker.
            model_name: optional model name, passed as ``fixed_model_name``.

        Returns:
            A Deferred that fires with the trained model's path on success,
            or with the failure after the project has been marked as failed.

        Raises:
            InvalidProjectError: if ``project`` is empty.
            MaxTrainingError: if all training slots are already in use.
            Exception: anything raised by ``self.pool.submit`` is re-raised
                after the slot counters are released and the project is
                marked as failed.
        """
        if not project:
            raise InvalidProjectError("Missing project name to train")

        if self._training_processes <= self._current_training_processes:
            raise MaxTrainingError

        if project not in self.project_store:
            self.project_store[project] = Project(self.component_builder,
                                                  project, self.project_dir,
                                                  self.remote_storage)
        self.project_store[project].status = STATUS_TRAINING

        def release_slot():
            # Undo the counter increments made when the training was queued.
            self._current_training_processes -= 1
            self.project_store[project].current_training_processes -= 1

        def mark_failed(error):
            # Log the error, free the slot and record the failure on the project.
            logger.warning(error)
            release_slot()
            self.project_store[project].status = STATUS_FAILED
            self.project_store[project].error_message = str(error)

        def training_callback(model_path):
            model_dir = os.path.basename(os.path.normpath(model_path))
            self.project_store[project].update(model_dir)
            release_slot()
            # Only switch to READY once no other training of this project
            # is still running.
            if (self.project_store[project].status == STATUS_TRAINING
                    and self.project_store[project].current_training_processes
                    == 0):
                self.project_store[project].status = STATUS_READY
            return model_path

        def training_errback(failure):
            mark_failed(failure)
            return failure

        logger.debug("New training queued")

        self._current_training_processes += 1
        self.project_store[project].current_training_processes += 1

        try:
            future = self.pool.submit(do_train_in_worker,
                                      train_config,
                                      data_file,
                                      path=self.project_dir,
                                      project=project,
                                      fixed_model_name=model_name,
                                      storage=self.remote_storage)
        except Exception as e:
            # Submission failed synchronously, so neither callback will ever
            # run; release the slot and record the failure here instead.
            mark_failed(e)
            raise

        result = deferred_from_future(future)
        result.addCallback(training_callback)
        # Added after the callback, so it also receives errors raised there.
        result.addErrback(training_errback)

        return result
示例#7
0
def test_dynamic_load_model_with_refresh_exists_model():
    """A requested model absent from ``_models`` is returned after a refresh.

    ``_models`` starts empty; the stubbed ``_search_for_models`` fills it
    with the requested name, and ``_dynamic_load_model`` is expected to
    return that name.
    """
    wanted = 'model_name'

    def skip_init(*args, **kwargs):
        return None

    def refresh_models(self):
        self._models = (wanted,)

    stubs = mock.patch.multiple(
        Project,
        __init__=skip_init,
        _search_for_models=refresh_models,
    )
    with stubs:
        instance = Project()
        instance._models = ()
        instance.pull_models = None

        assert instance._dynamic_load_model(wanted) == wanted