def __init__(
    self,
    sent_tokenizer,
    device: str,
    ext_model_name: str,
    config,
):
    """Build an extractive summarizer on top of a BrainBERT (RoBERTa) encoder.

    Args:
        sent_tokenizer: callable handed to ``BertSumTokenizer`` for
            sentence splitting.
        device: target device, either a string such as ``"cuda"`` /
            ``"cpu"`` or a ``torch.device``.
        ext_model_name: checkpoint name under ``bert/`` for the extractive
            encoder (e.g. ``"brainbert.base.ko.summary"``).
        config: task config; ``config.lang`` selects which resources are
            downloaded.
    """
    super().__init__(config)

    ckpt_dir = download_or_load(f"bert/{ext_model_name}", config.lang)
    tok_path = download_or_load(
        f"tokenizers/bpe32k.{config.lang}.zip",
        config.lang,
    )

    x = hub_utils.from_pretrained(
        ckpt_dir,
        "model.pt",
        load_checkpoint_heads=True,
    )
    wrapper = BrainRobertaHubInterface(
        x["args"],
        x["task"],
        x["models"][0],
        tok_path,
    )

    clf_dict = torch.load(
        f"{ckpt_dir}/classifier.pt",
        map_location=device,
    )
    # The classifier consumes the encoder's hidden states, so its width must
    # follow the encoder actually loaded (ext_model_name), not config.n_model,
    # which names a KoBART model in the bullet-summary path.
    classifier_size = 768 if "base" in ext_model_name else 1024

    self._device = device
    self._classifier = nn.Linear(classifier_size, 1).to(device).eval()
    self._classifier.load_state_dict(clf_dict)
    self._model = wrapper.model.encoder.sentence_encoder.to(device).eval()

    # Accept both str and torch.device: a plain string has no `.type`.
    if isinstance(device, torch.device):
        device_type = device.type
    else:
        device_type = torch.device(device).type
    if "cuda" in device_type:
        # Run encoder and classifier in half precision on GPU.
        self._model = self._model.half()
        self._classifier = self._classifier.half()

    self._tokenizer = BertSumTokenizer(
        bpe=wrapper.bpe,
        dictionary=wrapper.task.source_dictionary,
        sent_tokenizer=sent_tokenizer,
    )
def load(self, device: str):
    """
    Load user-selected task-specific model

    Args:
        device (str): device information

    Returns:
        object: User-selected task-specific model (``PororoOCR``), or
        ``None`` when ``config.n_model`` is not ``"brainocr"``.

    Raises:
        ValueError: if ``config.lang`` is not one of
            ``self.get_available_langs()``.
    """
    if self.config.n_model == "brainocr":
        # Validate before the (heavy) model import and downloads so bad
        # input fails fast.
        available_langs = self.get_available_langs()
        if self.config.lang not in available_langs:
            # Single message argument: passing two args to ValueError makes
            # str(err) render as a tuple.
            raise ValueError(
                f"Unsupported Language : {self.config.lang}. "
                f"Support Languages : {available_langs}"
            )

        from pororo.models.brainOCR import brainocr

        det_model_path = download_or_load(
            f"misc/{self.detect_model}.pt",
            self.config.lang,
        )
        rec_model_path = download_or_load(
            f"misc/{self.config.n_model}.pt",
            self.config.lang,
        )
        opt_fp = download_or_load(
            f"misc/{self.ocr_opt}.txt",
            self.config.lang,
        )
        model = brainocr.Reader(
            self.config.lang,
            det_model_ckpt_fp=det_model_path,
            rec_model_ckpt_fp=rec_model_path,
            opt_fp=opt_fp,
            device=device,
        )
        model.detector.to(device)
        model.recognizer.to(device)
        return PororoOCR(model, self.config)
def load(self, device: str):
    """Assemble the multilingual text-to-speech pipeline.

    Args:
        device (str): device information, forwarded to the synthesizer and
            to the ``PororoTTS`` wrapper.

    Returns:
        object: a ``PororoTTS`` instance when ``config.n_model`` is
        ``"tacotron"``; ``None`` for any other model name.
    """
    if self.config.n_model != "tacotron":
        return None

    from pororo.models.tts.synthesizer import MultilingualSpeechSynthesizer
    from pororo.models.tts.utils.numerical_pinyin_converter import (
        convert_from_numerical_pinyin,
    )
    from pororo.models.tts.utils.text import jejueo_romanize, romanize

    lang = self.config.lang

    # Fetched in this fixed order: Tacotron2 acoustic model, HiFi-GAN
    # vocoders (EN, KO), their configs (EN, KO), then WaveRNN.
    resource_names = (
        "misc/tacotron2",
        "misc/hifigan_en",
        "misc/hifigan_ko",
        "misc/hifigan_en_config.json",
        "misc/hifigan_ko_config.json",
        "misc/wavernn.pyt",
    )
    (
        tacotron_ckpt,
        en_vocoder_ckpt,
        ko_vocoder_ckpt,
        en_vocoder_cfg,
        ko_vocoder_cfg,
        wavernn_ckpt,
    ) = [download_or_load(name, lang) for name in resource_names]

    synth = MultilingualSpeechSynthesizer(
        tacotron_ckpt,
        en_vocoder_ckpt,
        en_vocoder_cfg,
        ko_vocoder_ckpt,
        ko_vocoder_cfg,
        wavernn_ckpt,
        device,
        lang,
    )

    text_helpers = (romanize, jejueo_romanize, convert_from_numerical_pinyin)
    return PororoTTS(synth, device, *text_helpers, self.config)
def load(self, device: str):
    """
    Load user-selected task-specific model

    Short aliases in ``config.n_model`` ("abstractive", "bullet",
    "extractive") are rewritten in place to their full model names before
    dispatch.

    Args:
        device (str): device information

    Returns:
        object: User-selected task-specific model, or ``None`` when
        ``config.n_model`` names neither a KoBART nor a BrainBERT model.
    """
    aliases = {
        "abstractive": "kobart.base.ko.summary",
        "bullet": "kobart.base.ko.bullet",
        "extractive": "brainbert.base.ko.summary",
    }
    if self.config.n_model in aliases:
        self.config.n_model = aliases[self.config.n_model]

    def _make_sent_tokenizer():
        # Build a sentence-splitting callable. The tokenization model is
        # loaded on first use and then reused. Building the factory and
        # calling .load() inside the callable would reload the model for
        # every text processed.
        from pororo.tasks.tokenization import PororoTokenizationFactory

        loaded = []

        def sent_tokenizer(text):
            if not loaded:
                loaded.append(
                    PororoTokenizationFactory(
                        task="tokenization",
                        lang=self.config.lang,
                        model=f"sent_{self.config.lang}",
                    ).load(device))
            return loaded[0].predict(text)

        return sent_tokenizer

    if "kobart" in self.config.n_model:
        from pororo.models.bart.KoBART import KoBartModel

        model_path = download_or_load(
            f"bart/{self.config.n_model}",
            self.config.lang,
        )
        model = KoBartModel.from_pretrained(
            device=device,
            model_path=model_path,
        )

        if "bullet" in self.config.n_model:
            # Bullet summaries also need an extractive BrainBERT summarizer.
            ext_model_name = "brainbert.base.ko.summary"
            ext_summary = PororoRobertaSummary(
                _make_sent_tokenizer(),
                device,
                ext_model_name,
                self.config,
            )
            return PororoKoBartBulletSummary(
                model=model,
                config=self.config,
                ext_summary=ext_summary,
            )

        return PororoKoBartSummary(model=model, config=self.config)

    if "brainbert" in self.config.n_model:
        return PororoRobertaSummary(
            _make_sent_tokenizer(),
            device,
            self.config.n_model,
            self.config,
        )