コード例 #1
0
    def load(
        cls,
        path: Union[Text, Path],
        new_config: Optional[Dict] = None,
        finetuning_epoch_fraction: float = 1.0,
    ) -> "PolicyEnsemble":
        """Loads policy and domain specification from disk.

        Each policy listed under ``metadata["policy_names"]`` is loaded from the
        sub-directory ``policy_{index}_{ClassName}`` of ``path``.

        Args:
            path: Directory the ensemble was persisted to.
            new_config: Optional new configuration. When truthy, every policy is
                loaded with ``should_finetune=True``; the entry at the same index
                in ``new_config["policies"]`` is handed to ``_get_updated_epochs``
                and, if that returns a truthy value, it is passed on as
                ``epoch_override``.
            finetuning_epoch_fraction: Forwarded to ``_get_updated_epochs``.

        Returns:
            An instance of the ensemble class named by
            ``metadata["ensemble_name"]``, built from the loaded policies and the
            stored action fingerprints (empty dict if none were stored).

        Raises:
            UnsupportedDialogueModelError: If a policy's ``load`` does not accept
                ``**kwargs`` but contextual arguments would have to be passed.
        """
        metadata = cls.load_metadata(path)
        cls.ensure_model_compatibility(metadata)

        loaded_policies = []
        for index, module_path in enumerate(metadata["policy_names"]):
            policy_class = registry.policy_from_module_path(module_path)
            policy_dir = os.path.join(path, f"policy_{index}_{policy_class.__name__}")

            # Contextual keyword arguments handed to the policy's `load`.
            load_kwargs = {}
            if new_config:
                load_kwargs["should_finetune"] = True
                updated_epochs = cls._get_updated_epochs(
                    policy_class,
                    new_config["policies"][index],
                    finetuning_epoch_fraction,
                )
                if updated_epochs:
                    load_kwargs["epoch_override"] = updated_epochs

            accepts_kwargs = "kwargs" in rasa.shared.utils.common.arguments_of(
                policy_class.load
            )
            # Legacy `load` signatures are only tolerated while nothing needs to
            # be passed to them; otherwise loading cannot proceed.
            if not accepts_kwargs and load_kwargs:
                raise UnsupportedDialogueModelError(
                    f"`{policy_class.__name__}.{policy_class.load.__name__}` does not "
                    f"accept `**kwargs`. Attempting to pass {load_kwargs} to the "
                    f"policy. `**kwargs` should be added to all policies by "
                    f"Rasa Open Source 3.0.0."
                )
            if not accepts_kwargs:
                rasa.shared.utils.io.raise_deprecation_warning(
                    f"`{policy_class.__name__}.{policy_class.load.__name__}` does not "
                    f"accept `**kwargs`. `**kwargs` are required for contextual "
                    f"information e.g. the flag `should_finetune`.",
                    warn_until_version="3.0.0",
                )

            policy = policy_class.load(policy_dir, **load_kwargs)
            cls._ensure_loaded_policy(policy, policy_class, module_path)
            loaded_policies.append(policy)

        ensemble_class = rasa.shared.utils.common.class_from_module_path(
            metadata["ensemble_name"]
        )
        return ensemble_class(loaded_policies, metadata.get("action_fingerprints", {}))
コード例 #2
0
ファイル: ensemble.py プロジェクト: tgalery/rasa_nlu
    def ensure_model_compatibility(metadata, version_to_check=None):
        """Raise if the persisted model is older than the minimum supported version.

        Args:
            metadata: Model metadata dictionary. The version the model was
                trained with is read from its ``"rasa"`` key; a missing key is
                treated as ``"0.0.0"``.
            version_to_check: Minimum version the model must have. Defaults to
                ``rasa.constants.MINIMUM_COMPATIBLE_VERSION`` when ``None``.

        Raises:
            UnsupportedDialogueModelError: If the model version compares lower
                than ``version_to_check``.
        """
        from packaging import version

        if version_to_check is None:
            version_to_check = rasa.constants.MINIMUM_COMPATIBLE_VERSION

        # Models persisted without a version are treated as the oldest possible.
        model_version = metadata.get("rasa", "0.0.0")
        if version.parse(model_version) < version.parse(version_to_check):
            raise UnsupportedDialogueModelError(
                "The model version is too old to be "
                "loaded by this Rasa Core instance. "
                "Either retrain the model, or run with "
                "an older version. "
                "Model version: {} Instance version: {} "
                "Minimal compatible version: {}"
                "".format(model_version, rasa.__version__, version_to_check),
                model_version,
            )
コード例 #3
0
ファイル: policy.py プロジェクト: wavymazy/rasa
    def load(cls, path: Union[Text, Path], **kwargs: Any) -> "Policy":
        """Loads a policy from path.

        Args:
            path: Path to load policy from.
            **kwargs: Contextual information (e.g. the flag `should_finetune`).
                It is merged on top of the persisted metadata, overriding
                persisted values with the same key, and passed to the
                constructor. It is not used when no metadata file exists.

        Returns:
            An instance of `Policy`.

        Raises:
            UnsupportedDialogueModelError: If the policy's `__init__` does not
                accept `**kwargs` and the merged arguments contain keys that
                `__init__` does not declare.
        """
        policy_dir = Path(path)
        metadata_file = policy_dir / cls._metadata_filename()

        if not metadata_file.is_file():
            logger.info(
                f"Couldn't load metadata for policy '{cls.__name__}'. "
                f"File '{metadata_file}' doesn't exist."
            )
            # NOTE(review): `kwargs` (e.g. finetuning flags) are not forwarded
            # here; confirm this is intended before changing it, since legacy
            # constructors may not accept them.
            return cls()

        data = json.loads(rasa.shared.utils.io.read_file(metadata_file))

        if (policy_dir / FEATURIZER_FILE).is_file():
            data["featurizer"] = TrackerFeaturizer.load(path)

        # Caller-supplied context wins over values persisted in the metadata.
        data.update(kwargs)

        constructor_args = rasa.shared.utils.common.arguments_of(cls)
        if "kwargs" not in constructor_args:
            # Legacy constructors are tolerated (with a deprecation warning) only
            # while every argument we would pass is one they explicitly declare.
            if set(data).issubset(constructor_args):
                rasa.shared.utils.io.raise_deprecation_warning(
                    f"`{cls.__name__}.__init__` does not accept `**kwargs`. "
                    f"This is required for contextual information e.g. the flag "
                    f"`should_finetune`.",
                    warn_until_version="3.0.0",
                )
            else:
                raise UnsupportedDialogueModelError(
                    f"`{cls.__name__}.__init__` does not accept `**kwargs`. "
                    f"Attempting to pass {data} to the policy. "
                    f"This argument should be added to all policies by "
                    f"Rasa Open Source 3.0.0."
                )

        return cls(**data)