示例#1
0
    def __init__(
        self,
        sent_tokenizer,
        device: str,
        ext_model_name: str,
        config,
    ):
        """
        Build the extractive (BrainBERT / RoBERTa) summarizer.

        Args:
            sent_tokenizer (callable): function splitting a text into
                sentences; handed to ``BertSumTokenizer``
            device (str or torch.device): device the encoder and the
                classifier head are moved to; accepted as either a string
                such as ``"cuda:0"`` or a ``torch.device``
            ext_model_name (str): name of the extractive checkpoint; used to
                locate ``bert/{ext_model_name}`` and to pick the width of the
                classifier head
            config: task configuration (``config.lang`` is read here)

        """
        super().__init__(config)
        ckpt_dir = download_or_load(f"bert/{ext_model_name}", config.lang)
        tok_path = download_or_load(
            f"tokenizers/bpe32k.{config.lang}.zip",
            config.lang,
        )

        x = hub_utils.from_pretrained(
            ckpt_dir,
            "model.pt",
            load_checkpoint_heads=True,
        )

        wrapper = BrainRobertaHubInterface(
            x["args"],
            x["task"],
            x["models"][0],
            tok_path,
        )

        clf_dict = torch.load(
            f"{ckpt_dir}/classifier.pt",
            map_location=device,
        )

        # Size the linear head from the checkpoint that was actually loaded
        # (``ckpt_dir`` comes from ``ext_model_name``). ``config.n_model`` can
        # name a different model, e.g. when the KoBART bullet summarizer
        # reuses this class with a BrainBERT extractive checkpoint.
        classifier_size = 768 if "base" in ext_model_name else 1024

        self._device = device
        self._classifier = nn.Linear(classifier_size, 1).to(device).eval()
        self._classifier.load_state_dict(clf_dict)
        self._model = wrapper.model.encoder.sentence_encoder.to(device).eval()

        # ``device`` may be a plain string or a torch.device; normalize it so
        # reading ``.type`` does not fail on strings. Half precision is only
        # applied on CUDA.
        if torch.device(device).type == "cuda":
            self._model = self._model.half()
            self._classifier = self._classifier.half()

        self._tokenizer = BertSumTokenizer(
            bpe=wrapper.bpe,
            dictionary=wrapper.task.source_dictionary,
            sent_tokenizer=sent_tokenizer,
        )
    def load(self, device: str):
        """
        Load user-selected task-specific model

        Args:
            device (str): device information

        Returns:
            object: User-selected task-specific model (a ``PororoOCR``), or
                ``None`` when ``self.config.n_model`` is not ``"brainocr"``

        Raises:
            ValueError: if ``self.config.lang`` is not one of
                ``self.get_available_langs()``

        """
        if self.config.n_model != "brainocr":
            return None

        # Validate before importing the (heavy) OCR package or downloading
        # anything, and report the languages this task actually supports.
        available_langs = self.get_available_langs()
        if self.config.lang not in available_langs:
            raise ValueError(
                f"Unsupported Language : {self.config.lang}. "
                f"Support Languages : {available_langs}"
            )

        from pororo.models.brainOCR import brainocr

        det_model_path = download_or_load(
            f"misc/{self.detect_model}.pt",
            self.config.lang,
        )
        rec_model_path = download_or_load(
            f"misc/{self.config.n_model}.pt",
            self.config.lang,
        )
        opt_fp = download_or_load(
            f"misc/{self.ocr_opt}.txt",
            self.config.lang,
        )
        model = brainocr.Reader(
            self.config.lang,
            det_model_ckpt_fp=det_model_path,
            rec_model_ckpt_fp=rec_model_path,
            opt_fp=opt_fp,
            device=device,
        )
        model.detector.to(device)
        model.recognizer.to(device)
        return PororoOCR(model, self.config)
示例#3
0
    def load(self, device: str):
        """
        Build the speech-synthesis model selected in ``self.config``.

        Only the ``"tacotron"`` model is handled. Its Tacotron2 checkpoint,
        the English and Korean HiFi-GAN vocoders with their JSON configs, and
        a WaveRNN checkpoint are fetched with ``download_or_load``. They are
        then wrapped in a ``MultilingualSpeechSynthesizer``.

        Args:
            device (str): device information

        Returns:
            object: a ``PororoTTS`` instance, or ``None`` for any other
                model name

        """
        if self.config.n_model != "tacotron":
            return None

        from pororo.models.tts.synthesizer import (
            MultilingualSpeechSynthesizer,
        )
        from pororo.models.tts.utils.numerical_pinyin_converter import (
            convert_from_numerical_pinyin,
        )
        from pororo.models.tts.utils.text import jejueo_romanize, romanize

        lang = self.config.lang
        # Resources are fetched one after another in this fixed order.
        (
            tacotron_ckpt,
            en_vocoder,
            ko_vocoder,
            en_vocoder_cfg,
            ko_vocoder_cfg,
            wavernn_ckpt,
        ) = [
            download_or_load(resource, lang)
            for resource in (
                "misc/tacotron2",
                "misc/hifigan_en",
                "misc/hifigan_ko",
                "misc/hifigan_en_config.json",
                "misc/hifigan_ko_config.json",
                "misc/wavernn.pyt",
            )
        ]

        synthesizer = MultilingualSpeechSynthesizer(
            tacotron_ckpt,
            en_vocoder,
            en_vocoder_cfg,
            ko_vocoder,
            ko_vocoder_cfg,
            wavernn_ckpt,
            device,
            lang,
        )
        return PororoTTS(
            synthesizer,
            device,
            romanize,
            jejueo_romanize,
            convert_from_numerical_pinyin,
            self.config,
        )
示例#4
0
    def load(self, device: str):
        """
        Load user-selected task-specific model

        The short aliases ``"abstractive"``, ``"bullet"`` and
        ``"extractive"`` are first resolved to full model names. This
        rewrites ``self.config.n_model`` in place. The KoBART
        (abstractive / bullet) or BrainBERT (extractive) summarizer is then
        built.

        Args:
            device (str): device information

        Returns:
            object: User-selected task-specific model, or ``None`` when
                ``self.config.n_model`` contains neither ``"kobart"`` nor
                ``"brainbert"``

        """
        from pororo.tasks.tokenization import PororoTokenizationFactory

        aliases = {
            "abstractive": "kobart.base.ko.summary",
            "bullet": "kobart.base.ko.bullet",
            "extractive": "brainbert.base.ko.summary",
        }
        if self.config.n_model in aliases:
            self.config.n_model = aliases[self.config.n_model]

        def build_sent_tokenizer():
            # Load the sentence splitter once and hand out its ``predict``.
            # Building the factory inside a lambda would re-create it and
            # re-load the tokenizer on every single call.
            splitter = PororoTokenizationFactory(
                task="tokenization",
                lang=self.config.lang,
                model=f"sent_{self.config.lang}",
            ).load(device)
            return splitter.predict

        if "kobart" in self.config.n_model:
            from pororo.models.bart.KoBART import KoBartModel

            model_path = download_or_load(
                f"bart/{self.config.n_model}",
                self.config.lang,
            )

            model = KoBartModel.from_pretrained(
                device=device,
                model_path=model_path,
            )

            if "bullet" in self.config.n_model:
                # Bullet summaries pair the abstractive model with an
                # extractive BrainBERT summarizer.
                ext_summary = PororoRobertaSummary(
                    build_sent_tokenizer(),
                    device,
                    "brainbert.base.ko.summary",
                    self.config,
                )

                return PororoKoBartBulletSummary(
                    model=model,
                    config=self.config,
                    ext_summary=ext_summary,
                )

            return PororoKoBartSummary(model=model, config=self.config)

        if "brainbert" in self.config.n_model:
            return PororoRobertaSummary(
                build_sent_tokenizer(),
                device,
                self.config.n_model,
                self.config,
            )

        return None