Example #1
    def from_config(cls, config: Config):
        # TODO: T57433776 remove once FairSeq supports PathManager
        config.bpe_encoder_path = PathManager.get_local_path(config.bpe_encoder_path)
        config.bpe_vocab_path = PathManager.get_local_path(config.bpe_vocab_path)

        bpe = create_gpt2_bpe(config.bpe_encoder_path, config.bpe_vocab_path)
        # This hacks the bpe instance to be picklable
        bpe = copy.copy(bpe)
        bpe.__class__ = PickleableGPT2BPEEncoder

        return cls(bpe)
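
The `__class__` reassignment above is a general trick for making an otherwise unpicklable object picklable: shallow-copy the instance and route it through a subclass that (typically) defines `__getstate__`/`__setstate__`. A minimal standalone sketch of the same idea (the class names and attributes here are illustrative, not from PyText):

import copy
import pickle


class Encoder:
    """Stand-in for a third-party class that pickle cannot handle as-is."""

    def __init__(self, vocab):
        self.vocab = vocab
        self._scorer = lambda token: len(token)  # lambdas are not picklable


class PickleableEncoder(Encoder):
    """Serialize only the picklable state and rebuild the rest on load."""

    def __getstate__(self):
        return {"vocab": self.vocab}

    def __setstate__(self, state):
        self.__init__(state["vocab"])


encoder = Encoder(vocab={"hello": 0})
# Same hack as above: shallow-copy the instance, then swap its class so that
# pickle goes through the subclass's __getstate__/__setstate__.
encoder = copy.copy(encoder)
encoder.__class__ = PickleableEncoder

restored = pickle.loads(pickle.dumps(encoder))
print(restored.vocab)  # {'hello': 0}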
Example #2
    def __init__(self, sp_model_path: Optional[str] = None):
        super().__init__()

        # This default spm file path is a dummy link as we haven't published
        # the file yet. Please provide your own spm file path when using
        # this transform
        sp_model_path = sp_model_path or url.URL[url.SP_MODEL]
        local_path = PathManager.get_local_path(sp_model_path)
        self.sp_model = load_sp_model(local_path)
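
`get_local_path` is the recurring call in these examples: it materializes remote files (for example, files behind a URL) into a local cache and passes plain filesystem paths through untouched. A rough standalone sketch of that resolution step using the iopath library directly, on the assumption that PyText's `PathManager` wraps an iopath-style path manager (the iopath dependency and the file name are assumptions, not part of the snippet):

from iopath.common.file_io import HTTPURLHandler, PathManager

path_manager = PathManager()
path_manager.register_handler(HTTPURLHandler())

# For an ordinary filesystem path this is effectively a pass-through; for an
# http(s):// URL the handler downloads the file once into a local cache
# directory and returns the cached path.
print(path_manager.get_local_path("spm.model"))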
Example #3
    def _load_processor(self):
        sp_model_path = PathManager.get_local_path(self.sp_model_path)
        if self.use_fb_sentencepiece:
            self.processor = torch.classes.fb.SentencePiece.fromFile(sp_model_path)
        else:
            from sentencepiece import SentencePieceProcessor

            self.processor = SentencePieceProcessor()
            self.processor.Load(sp_model_path)
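
When the internal `torch.classes.fb.SentencePiece` backend is unavailable, the snippet falls back to the open-source `sentencepiece` package. A minimal standalone usage sketch of that fallback path, assuming `sentencepiece` is installed and `spm.model` is a local model file (both assumptions):

from sentencepiece import SentencePieceProcessor

processor = SentencePieceProcessor()
processor.Load("spm.model")  # hypothetical local .model file

# Tokenize into subword pieces and into their integer ids.
print(processor.EncodeAsPieces("hello world"))
print(processor.EncodeAsIds("hello world"))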
Example #4
    def _load_processor(self):
        if getattr(self, "use_fb_sentencepiece", None):
            try:
                import importlib.resources

                import sentencepiece_model

                with importlib.resources.path(sentencepiece_model,
                                              "model") as sp_model_path:
                    self.processor = torch.classes.fb.SentencePiece.fromFile(
                        str(sp_model_path))
            except Exception:
                sp_model_path = PathManager.get_local_path(self.sp_model_path)
                self.processor = torch.classes.fb.SentencePiece.fromFile(
                    sp_model_path)
        else:
            from sentencepiece import SentencePieceProcessor

            sp_model_path = PathManager.get_local_path(self.sp_model_path)
            self.processor = SentencePieceProcessor()
            self.processor.Load(sp_model_path)
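
The pattern above prefers a model bundled inside an installed package and falls back to the configured `sp_model_path` if the package or resource cannot be loaded. A distilled, self-contained version of that fallback (the package and file names are illustrative):

import importlib
import importlib.resources


def resolve_model_path(package_name: str, resource: str, fallback_path: str) -> str:
    """Return the path of a resource bundled in an installed package,
    falling back to a configured filesystem path if it is unavailable."""
    try:
        package = importlib.import_module(package_name)
        with importlib.resources.path(package, resource) as bundled:
            return str(bundled)
    except (ImportError, FileNotFoundError):
        return fallback_path


# "sentencepiece_model" is a hypothetical package, as in the snippet above.
print(resolve_model_path("sentencepiece_model", "model", "/tmp/spm.model"))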
Example #5
def create_module(
    module_config, *args, create_fn=_create_module_from_registry, **kwargs
):
    """Create module object given the module's config object. It depends on the
    global shared module registry. Hence, your module must be available for the
    registry. This entails that your module must be imported somewhere in the
    code path during module creation (ideally in your model class) for the module
    to be visible for registry.

    Args:
        module_config (type): Module config object.
        create_fn (type): The function to use for creating the module. Use this
            parameter if your module creation requires custom code and pass your
            function here. Defaults to `_create_module_from_registry()`.

    Returns:
        The created module, or a previously created module reused from the
        shared registry when a `shared_module_key` is set.
    """
    # the first module with a given shared_module_key and type is saved in
    # SHARED_MODULE_REGISTRY.  The rest will reuse the saved module and thus
    # share parameters.
    shared_module_key = getattr(module_config, "shared_module_key", None)
    typed_shared_module_key = (shared_module_key, type(module_config))
    load_path = getattr(module_config, "load_path", None)
    is_torchscript_load_path = load_path and zipfile.is_zipfile(
        PathManager.get_local_path(load_path)
    )
    module = SHARED_MODULE_REGISTRY.get(typed_shared_module_key)
    if not module:
        if is_torchscript_load_path:
            with PathManager.open(load_path, "rb") as load_file:
                module = torch.jit.load(load_file)
        else:
            module = create_fn(module_config, *args, **kwargs)

    name = type(module).__name__
    if load_path and not is_torchscript_load_path:
        print(f"Loading state of module {name} from {load_path} ...")
        with PathManager.open(load_path, "rb") as load_file:
            module.load_state_dict(torch.load(load_file, map_location="cpu"))
    if getattr(module_config, "freeze", False):
        print(f"Freezing the parameters of module {name} ...")
        module.freeze()
    if shared_module_key:
        SHARED_MODULE_REGISTRY[typed_shared_module_key] = module
    module.save_path = getattr(module_config, "save_path", None)
    return module
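
The core of `create_module` is the `(shared_module_key, config type)` lookup: the first module created under a key is stored in the registry, and later calls with the same key reuse that instance, so the callers share parameters. A stripped-down sketch of that mechanism (the registry and helper names here are illustrative, not PyText's):

import torch.nn as nn

_SHARED_MODULES = {}


def create_linear(in_dim: int, out_dim: int, shared_module_key=None) -> nn.Linear:
    """Create an nn.Linear, reusing a previously created instance when the
    same shared_module_key is given, so callers share its parameters."""
    typed_key = (shared_module_key, nn.Linear)
    module = _SHARED_MODULES.get(typed_key)
    if module is None:
        module = nn.Linear(in_dim, out_dim)
    if shared_module_key:
        _SHARED_MODULES[typed_key] = module
    return module


a = create_linear(4, 2, shared_module_key="proj")
b = create_linear(4, 2, shared_module_key="proj")
assert a is b  # same object, hence shared parameters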
Example #6
def batch_predict_caffe2_model(
    pytext_model_file: str,
    caffe2_model_file: str,
    db_type: str = CAFFE2_DB_TYPE,
    data_source: Optional[DataSource] = None,
    use_cuda=False,
    task: Optional[NewTask] = None,
    train_config: Optional[PyTextConfig] = None,
    cache_size: int = 0,
):
    """
    Gets predictions from a Caffe2 model for a batch of examples.

    Args:
        pytext_model_file: Path to the PyText model file (required if task and
            training config are not specified)
        caffe2_model_file: Path to caffe2 model file
        db_type: DB type to use for caffe2
        data_source: Data source for test examples
        use_cuda: Whether to turn on cuda processing
        task: The pytext task object
        train_config: The pytext training config
        cache_size: The LRU cache size to use for prediction. 0 = no cache,
            -1 = boundless cache, [1, inf) = size of cache
    """
    logging.info(f"Loading data processing config from {pytext_model_file}")

    _set_cuda(use_cuda)
    if task is None or train_config is None:
        task, train_config, _ = load(pytext_model_file)

    data_source = data_source or task.data.data_source
    logging.info(f"Loading Caffe2 model: {caffe2_model_file}")
    predictor = create_predictor(
        train_config,
        PathManager.get_local_path(caffe2_model_file),
        db_type,
        task,
        cache_size,
    )
    logging.info(f"Model loaded, start testing")
    predictions = [predictor(example) for example in data_source.test]
    return predictions
Example #7
def create_predictor(
    config: PyTextConfig,
    model_file: Optional[str] = None,
    db_type: str = CAFFE2_DB_TYPE,
    task: Optional[NewTask] = None,
    cache_size: int = 0,
) -> Predictor:
    """
    Create a simple prediction API from a training config and an exported caffe2
    model file. This model file should be created by calling export on a trained
    model snapshot.
    """
    workspace_id = str(uuid.uuid4())
    workspace.SwitchWorkspace(workspace_id, True)
    predict_net = predictor_exporter.prepare_prediction_net(
        filename=model_file
        or PathManager.get_local_path(config.export_caffe2_path),
        db_type=db_type,
    )

    new_task = task or NewTask.from_config(config.task)
    input_tensorizers = {
        name: tensorizer
        for name, tensorizer in new_task.data.tensorizers.items()
        if tensorizer.is_input
    }

    def predict_fn(input):
        return _predict(workspace_id, predict_net, new_task.model,
                        input_tensorizers, input)

    if cache_size < 0:
        return lru_cache(maxsize=None)(predict_fn)
    elif cache_size > 0:
        return lru_cache(maxsize=cache_size)(predict_fn)
    else:
        return predict_fn
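
The tail of `create_predictor` applies the cache policy documented in Example #6: `cache_size == 0` disables caching, a negative value gives a boundless LRU cache, and a positive value bounds it. A small standalone illustration of that policy (the helper name is illustrative; inputs must be hashable for `lru_cache` to apply):

from functools import lru_cache


def with_prediction_cache(predict_fn, cache_size: int):
    """Wrap predict_fn with the same cache policy as create_predictor."""
    if cache_size < 0:
        return lru_cache(maxsize=None)(predict_fn)  # boundless cache
    if cache_size > 0:
        return lru_cache(maxsize=cache_size)(predict_fn)
    return predict_fn  # no caching


predict = with_prediction_cache(lambda x: x * 2, cache_size=128)
predict(21)
predict(21)  # second call is served from the cache
print(predict.cache_info())  # hits=1, misses=1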
Example #8
    def _load_processor(self):
        self.processor = SentencePieceProcessor()
        self.processor.Load(PathManager.get_local_path(self.sp_model_path))
Example #9
    def __init__(
        self, config: Config, output_encoded_layers: bool, *args, **kwargs
    ) -> None:
        super().__init__(config, output_encoded_layers=output_encoded_layers)
        # Load config
        config_file = os.path.join(config.bert_cpt_dir, "config.json")
        local_config_path = PathManager.get_local_path(config_file)
        bert_config = BertConfig.from_json_file(local_config_path)
        print("Bert model config {}".format(bert_config))
        # Instantiate model.
        model = BertModel(bert_config)
        weights_path = os.path.join(config.bert_cpt_dir, "pytorch_model.bin")
        # load pre-trained weights if weights_path exists
        if config.load_weights and PathManager.isfile(weights_path):
            with PathManager.open(weights_path, "rb") as fd:
                state_dict = torch.load(fd)

            missing_keys: List[str] = []
            unexpected_keys: List[str] = []
            error_msgs: List[str] = []
            # copy state_dict so _load_from_state_dict can modify it
            metadata = getattr(state_dict, "_metadata", None)
            for key in list(state_dict.keys()):
                new_key = None
                if key.endswith("LayerNorm.gamma"):  # compatibility with v0.5 models
                    new_key = key.replace("LayerNorm.gamma", "LayerNorm.weight")
                if key.endswith("LayerNorm.beta"):  # compatibility with v0.5 models
                    new_key = key.replace("LayerNorm.beta", "LayerNorm.bias")
                if new_key is not None:
                    state_dict[new_key] = state_dict.pop(key)

            if metadata is not None:
                state_dict._metadata = metadata

            def load(module, prefix=""):
                local_metadata = (
                    {} if metadata is None else metadata.get(prefix[:-1], {})
                )
                module._load_from_state_dict(
                    state_dict,
                    prefix,
                    local_metadata,
                    True,
                    missing_keys,
                    unexpected_keys,
                    error_msgs,
                )
                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + ".")

            load(model, prefix="" if hasattr(model, "bert") else "bert.")
            if len(missing_keys) > 0:
                print(
                    "Weights of {} not initialized from pretrained model: {}".format(
                        model.__class__.__name__, missing_keys
                    )
                )
            if len(unexpected_keys) > 0:
                print(
                    "Weights from pretrained model not used in {}: {}".format(
                        model.__class__.__name__, unexpected_keys
                    )
                )

        self.bert = model
        log_class_usage(__class__)
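
The key-renaming loop in the middle of this constructor upgrades checkpoints saved with the old LayerNorm parameter names (`gamma`/`beta`) to the names current PyTorch modules expect (`weight`/`bias`). The same step in isolation, applied to a toy checkpoint (the checkpoint contents are illustrative):

import torch


def upgrade_layernorm_keys(state_dict):
    """Rename legacy LayerNorm parameter keys, as in Example #9."""
    for key in list(state_dict.keys()):
        new_key = None
        if key.endswith("LayerNorm.gamma"):
            new_key = key.replace("LayerNorm.gamma", "LayerNorm.weight")
        elif key.endswith("LayerNorm.beta"):
            new_key = key.replace("LayerNorm.beta", "LayerNorm.bias")
        if new_key is not None:
            state_dict[new_key] = state_dict.pop(key)
    return state_dict


old_checkpoint = {
    "encoder.LayerNorm.gamma": torch.ones(4),
    "encoder.LayerNorm.beta": torch.zeros(4),
}
print(sorted(upgrade_layernorm_keys(old_checkpoint)))
# ['encoder.LayerNorm.bias', 'encoder.LayerNorm.weight']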