Example #1
    def from_archive(
        cls,
        archive: Archive,
        predictor_name: str = None,
        dataset_reader_to_load: str = "validation",
        frozen: bool = True,
        language: str = "en_core_web_sm",
        restrict_frames: bool = False,
        restrict_roles: bool = False,
    ) -> "Predictor":
        # Duplicate the config so that the config inside the archive doesn't get consumed
        config = archive.config.duplicate()

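        # Fall back to the model class's default_predictor when no predictor name is given.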
        if not predictor_name:
            model_type = config.get("model").get("type")
            model_class, _ = Model.resolve_class_name(model_type)
            predictor_name = model_class.default_predictor
        predictor_class: Type[Predictor] = (
            Predictor.by_name(predictor_name) if predictor_name is not None else cls  # type: ignore
        )

        if dataset_reader_to_load == "validation" and "validation_dataset_reader" in config:
            dataset_reader_params = config["validation_dataset_reader"]
        else:
            dataset_reader_params = config["dataset_reader"]
        dataset_reader = DatasetReader.from_params(dataset_reader_params)

        model = archive.model
        if frozen:
            model.restrict_frames = restrict_frames
            model.restrict_roles = restrict_roles
            model.eval()

        return predictor_class(model, dataset_reader, language)
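
A minimal usage sketch for the classmethod above, assuming it is defined on a Predictor subclass (called MyPredictor here purely as a placeholder) and that an archived model exists at a placeholder path:

from allennlp.models.archival import load_archive

archive = load_archive("/path/to/model.tar.gz")  # placeholder archive path
predictor = MyPredictor.from_archive(
    archive,
    dataset_reader_to_load="validation",  # prefer the validation reader if the config defines one
    frozen=True,                          # copies the restrict_* flags onto the model and calls eval()
)
# Input keys depend on the concrete predictor; "sentence" is only an illustration.
outputs = predictor.predict_json({"sentence": "The quick brown fox jumps over the lazy dog."})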
Example #2
def predictor_from_archive(archive: Archive,
                           predictor_name: str = None,
                           paper_features_path: str = None) -> 'Predictor':
    """
    Extends allennlp.predictors.predictor.from_archive to allow processing multiprocess reader

    paper_features_path is passed to replace the correct one if the dataset_reader is multiprocess
    """

    # Duplicate the config so that the config inside the archive doesn't get consumed
    config = archive.config.duplicate()

    if not predictor_name:
        model_type = config.get("model").get("type")
        if model_type not in DEFAULT_PREDICTORS:
            raise ConfigurationError(f"No default predictor for model type {model_type}.\n"
                                     "Please specify a predictor explicitly.")
        predictor_name = DEFAULT_PREDICTORS[model_type]

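    # Unwrap a multiprocess reader so the base reader's params (and optional paper_features_path override) are used.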
    dataset_config = config["dataset_reader"].as_dict()
    if dataset_config['type'] == 'multiprocess':
        dataset_config = dataset_config['base_reader']
        if paper_features_path:
            dataset_config['paper_features_path'] = paper_features_path
        dataset_reader_params = Params(dataset_config)

    else:
        dataset_reader_params = config["dataset_reader"]

    dataset_reader = DatasetReader.from_params(dataset_reader_params)

    model = archive.model
    model.eval()

    return Predictor.by_name(predictor_name)(model, dataset_reader)
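
A hedged usage sketch for predictor_from_archive; the archive path, registered predictor name, and features path below are placeholders rather than values taken from the original code:

from allennlp.models.archival import load_archive

archive = load_archive("model.tar.gz")  # placeholder archive path
predictor = predictor_from_archive(
    archive,
    predictor_name="paper_predictor",                  # placeholder registered name
    paper_features_path="data/paper_features.jsonl",   # used only when the archived reader is multiprocess
)
# Input keys depend on the chosen predictor.
outputs = predictor.predict_json({"text": "..."})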
Example #3
    def __init__(self, model: MultiTaskModel,
                 dataset_reader: MultiTaskDatasetReader) -> None:
        if not isinstance(dataset_reader, MultiTaskDatasetReader):
            raise ConfigurationError(self._WRONG_READER_ERROR)

        if not isinstance(model, MultiTaskModel):
            raise ConfigurationError(
                "MultiTaskPredictor is designed to work only with MultiTaskModel."
            )

        super().__init__(model, dataset_reader)

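        # Build one predictor per model head, falling back to the base Predictor class when a head has no default_predictor.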
        self.predictors = {}
        for name, head in model._heads.items():
            predictor_name = head.default_predictor
            predictor_class: Type[Predictor] = (
                Predictor.by_name(predictor_name)
                if predictor_name is not None else Predictor  # type: ignore
            )
            self.predictors[name] = predictor_class(
                model, dataset_reader.readers[name].inner)
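
A short usage sketch; constructing the MultiTaskModel and MultiTaskDatasetReader is elided, and the head name used for lookup is a placeholder:

predictor = MultiTaskPredictor(model, dataset_reader)
# self.predictors holds one wrapped predictor per head in model._heads.
print(sorted(predictor.predictors))
# A single head's predictor can then be looked up by its head name, e.g. predictor.predictors["ner"].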