Example #1
    def __init__(self,
                 ckpt_path='output/macbert4csc/epoch=09-val_loss=0.01.ckpt',
                 vocab_path='output/macbert4csc/vocab.txt',
                 cfg_path='train_macbert4csc.yml'):
        logger.debug("device: {}".format(device))
        self.tokenizer = BertTokenizer.from_pretrained(vocab_path)
        cfg.merge_from_file(cfg_path)
        if 'macbert4csc' in cfg_path:
            self.model = MacBert4Csc.load_from_checkpoint(
                checkpoint_path=ckpt_path,
                cfg=cfg,
                map_location=device,
                tokenizer=self.tokenizer)
        elif 'softmaskedbert4csc' in cfg_path:
            self.model = SoftMaskedBert4Csc.load_from_checkpoint(
                checkpoint_path=ckpt_path,
                cfg=cfg,
                map_location=device,
                tokenizer=self.tokenizer)
        else:
            raise ValueError("model not found.")

        self.model.eval()
        self.model.to(device)
        logger.debug("device: {}".format(device))
Example #2
def _fetch_from_remote(url,
                       force_download=False,
                       cached_dir='~/.paddle-ernie-cache'):
    import hashlib
    import os
    import tarfile
    import requests
    from pathlib import Path
    from tqdm import tqdm
    sig = hashlib.md5(url.encode('utf8')).hexdigest()
    cached_dir = Path(cached_dir).expanduser()
    cached_dir.mkdir(parents=True, exist_ok=True)
    cached_dir_model = cached_dir / sig
    if force_download or not cached_dir_model.exists():
        cached_dir_model.mkdir(exist_ok=True)
        tmpfile = cached_dir_model / 'tmp'
        with tmpfile.open('wb') as f:
            # url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz'
            r = requests.get(url, stream=True)
            total_len = int(r.headers.get('content-length', 0))
            for chunk in tqdm(r.iter_content(chunk_size=1024),
                              total=total_len // 1024,
                              desc='downloading %s' % url,
                              unit='KB'):
                if chunk:
                    f.write(chunk)
                    f.flush()
            logger.debug('extracting %s to %s' % (tmpfile, cached_dir_model))
            with tarfile.open(tmpfile.as_posix()) as tf:
                tf.extractall(path=cached_dir_model.as_posix())
        os.remove(tmpfile.as_posix())
    logger.debug('%s cached in %s' % (url, cached_dir))
    return cached_dir_model
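
A minimal usage sketch for the helper above; the archive URL is taken from the commented hint inside the function and is illustrative only:

# hypothetical usage: download an archive once, then reuse the cached copy
url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz'
model_dir = _fetch_from_remote(url)                       # downloads and extracts on first call
model_dir = _fetch_from_remote(url, force_download=True)  # ignores the cache and re-downloads
print(model_dir)  # ~/.paddle-ernie-cache/<md5-of-url>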
Example #3
    def initialize_detector(self):
        t1 = time.time()
        if self.enable_rnnlm:
            self.lm = LM(self.rnnlm_model_dir, self.rnnlm_vocab_path)
            logger.debug('Loaded language model: %s, spend: %s s' %
                         (self.rnnlm_model_dir, str(time.time() - t1)))
        else:
            try:
                import kenlm
            except ImportError:
                raise ImportError(
                    'pycorrector dependencies are not fully installed; '
                    'they are required for the statistical language model. '
                    'Please run "pip install kenlm" to install it (kenlm is not supported on Windows). '
                    'On Windows, please install tensorflow and set enable_rnnlm=True.'
                )

            self.lm = kenlm.Model(self.language_model_path)
            logger.debug('Loaded language model: %s, spend: %s s' %
                         (self.language_model_path, str(time.time() - t1)))

        # word-frequency dict
        t2 = time.time()
        self.word_freq = self.load_word_freq_dict(self.word_freq_path)
        t3 = time.time()
        logger.debug('Loaded word freq file: %s, size: %d, spend: %s s' %
                     (self.word_freq_path, len(self.word_freq), str(t3 - t2)))
        # custom confusion set
        self.custom_confusion = self._get_custom_confusion_dict(
            self.custom_confusion_path)
        t4 = time.time()
        logger.debug('Loaded confusion file: %s, size: %d, spend: %s s' %
                     (self.custom_confusion_path, len(
                         self.custom_confusion), str(t4 - t3)))
        # custom segmentation dictionary
        self.custom_word_freq = self.load_word_freq_dict(
            self.custom_word_freq_path)
        self.person_names = self.load_word_freq_dict(self.person_name_path)
        self.place_names = self.load_word_freq_dict(self.place_name_path)
        self.stopwords = self.load_word_freq_dict(self.stopwords_path)
        # merge the segmentation dictionary with the custom dictionaries
        self.custom_word_freq.update(self.person_names)
        self.custom_word_freq.update(self.place_names)
        self.custom_word_freq.update(self.stopwords)

        self.word_freq.update(self.custom_word_freq)
        t5 = time.time()
        logger.debug('Loaded custom word file: %s, size: %d, spend: %s s' %
                     (self.custom_word_freq_path,
                      len(self.custom_word_freq), str(t5 - t4)))
        self.tokenizer = Tokenizer(dict_path=self.word_freq_path,
                                   custom_word_freq_dict=self.custom_word_freq,
                                   custom_confusion_dict=self.custom_confusion)
        t6 = time.time()
        logger.info('Loaded dict ok, spend: %s s' % str(t6 - t1))
        self.initialized_detector = True
Example #4
 def __init__(self, bert_model_dir=os.path.join(pwd_path, '../data/bert_models/chinese_finetuned_lm/')):
     super(BertCorrector, self).__init__()
     self.name = 'bert_corrector'
     t1 = time.time()
     self.model = pipeline('fill-mask',
                           model=bert_model_dir,
                           tokenizer=bert_model_dir)
     if self.model:
         self.mask = self.model.tokenizer.mask_token
         logger.debug('Loaded bert model: %s, spend: %.3f s.' % (bert_model_dir, time.time() - t1))
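
For context, a self-contained sketch of what such a fill-mask pipeline returns; the public checkpoint name is an assumption, any local BERT MLM directory works the same way:

from transformers import pipeline

fill_mask = pipeline('fill-mask', model='bert-base-chinese')  # illustrative checkpoint
mask = fill_mask.tokenizer.mask_token  # '[MASK]'
# mask one suspicious character and let the MLM rank replacement candidates
for pred in fill_mask('少先队员因该为老人让座'.replace('因', mask, 1)):
    print(pred['token_str'], round(pred['score'], 4))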
Example #5
 def initialize_bert_corrector(self):
     t1 = time.time()
     self.bert_tokenizer = BertTokenizer(self.bert_model_vocab)
     self.MASK_ID = self.bert_tokenizer.convert_tokens_to_ids([MASK_TOKEN])[0]
     # Prepare model
     self.model = BertForMaskedLM.from_pretrained(self.bert_model_dir)
     logger.debug("Loaded model ok, path: %s, spend: %.3f s." %
                  (self.bert_model_dir, time.time() - t1))
     self.initialized_bert_corrector = True
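
The masked position can also be scored by hand, which is essentially what this snippet sets up; a minimal sketch with an illustrative public checkpoint:

import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # illustrative checkpoint
model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model.eval()

inputs = tokenizer('中国的首都是[MASK]京', return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits
# locate the mask position and take the highest-scoring token there
mask_pos = (inputs['input_ids'][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
predicted_ids = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.convert_ids_to_tokens(predicted_ids.tolist()))  # expected: ['北']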
Example #6
 def set_en_custom_confusion_dict(self, path):
     """
     设置混淆纠错词典
     :param path:
     :return:
     """
     self.check_init()
     self.custom_confusion = self._get_custom_confusion_dict(path)
     logger.debug('Loaded en spell confusion path: %s, size: %d' %
                  (path, len(self.custom_confusion)))
Example #7
 def initialize_corrector(self):
     t1 = time.time()
     # chinese common char dict
     self.cn_char_set = load_char_set(self.common_char_path)
     # same pinyin
     self.same_pinyin = load_same_pinyin(self.same_pinyin_text_path)
     # same stroke
     self.same_stroke = load_same_stroke(self.same_stroke_text_path)
     logger.debug("Loaded same pinyin file: %s, same stroke file: %s, spend: %.3f s." % (
         self.same_pinyin_text_path, self.same_stroke_text_path, time.time() - t1))
     self.initialized_corrector = True
Example #8
 def _init(self):
     with gzip.open(config.en_dict_path, "rb") as f:
         all_word_freq_dict = json.loads(f.read().decode("utf-8"))
         word_freq = {}
         for k, v in all_word_freq_dict.items():
             # ~30,000 common English words; keep those with frequency above 400
             if v > 400:
                 word_freq[k] = v
         self.word_freq_dict = word_freq
         logger.debug("load en spell data: %s, size: %d" %
                      (config.en_dict_path, len(self.word_freq_dict)))
Example #9
 def set_custom_word(self, path):
     self.check_detector_initialized()
     word_freqs = self.load_word_freq_dict(path)
     # merge dictionaries
     self.custom_word_freq.update(word_freqs)
     # merge the segmentation dictionary with the custom dictionary
     self.word_freq.update(self.custom_word_freq)
     self.tokenizer = Tokenizer(dict_path=self.word_freq_path, custom_word_freq_dict=self.custom_word_freq,
                                custom_confusion_dict=self.custom_confusion)
     for k, v in word_freqs.items():
         self.set_word_frequency(k, v)
     logger.debug('Loaded custom word path: %s, size: %d' % (path, len(word_freqs)))
Example #10
 def detect(self, sentence):
     """
     句子改错
     :param sentence: 句子文本
     :return: list[list], [error_word, begin_pos, end_pos, error_type]
     """
     maybe_errors = []
     for prob, f in self.predict_token_prob(sentence):
         logger.debug('prob:%s, token:%s, idx:%s' % (prob, f.token, f.id))
         if prob < self.threshold:
             maybe_errors.append([f.token, f.id, f.id + 1, ErrorType.char])
     return maybe_errors
Example #11
 def __init__(self, macbert_model_dir=config.macbert_model_dir):
     super(MacBertCorrector, self).__init__()
     self.name = 'macbert_corrector'
     t1 = time.time()
     if not os.path.exists(os.path.join(macbert_model_dir, 'vocab.txt')):
         macbert_model_dir = "shibing624/macbert4csc-base-chinese"
     self.tokenizer = BertTokenizer.from_pretrained(macbert_model_dir)
     self.model = BertForMaskedLM.from_pretrained(macbert_model_dir)
     self.model.to(device)
     self.unk_tokens = [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']
     logger.debug("device: {}".format(device))
     logger.debug('Loaded macbert model: %s, spend: %.3f s.' %
                  (macbert_model_dir, time.time() - t1))
Example #12
 def __init__(self, bert_model_dir=config.bert_model_dir, device=device_id):
     super(BertCorrector, self).__init__()
     self.name = 'bert_corrector'
     t1 = time.time()
     self.model = pipeline(
         'fill-mask',
         model=bert_model_dir,
         tokenizer=bert_model_dir,
         device=device,  # gpu device id
     )
     if self.model:
         self.mask = self.model.tokenizer.mask_token
         logger.debug('Loaded bert model: %s, spend: %.3f s.' % (bert_model_dir, time.time() - t1))
Example #13
    def __init__(self, d_model_dir=D_model_dir, g_model_dir=G_model_dir):
        super(ElectraCorrector, self).__init__()
        self.name = 'electra_corrector'
        t1 = time.time()
        self.g_model = pipeline("fill-mask",
                                model=g_model_dir,
                                tokenizer=g_model_dir)
        self.d_model = ElectraForPreTraining.from_pretrained(d_model_dir)

        if self.g_model:
            self.mask = self.g_model.tokenizer.mask_token
            logger.debug('Loaded electra model: %s, spend: %.3f s.' %
                         (g_model_dir, time.time() - t1))
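
For reference, the discriminator half (d_model) can flag suspicious positions on its own: ElectraForPreTraining emits one logit per token, and a positive value marks a token the model believes was replaced. A sketch with an illustrative English checkpoint:

import torch
from transformers import ElectraTokenizer, ElectraForPreTraining

name = 'google/electra-small-discriminator'  # illustrative public checkpoint
tokenizer = ElectraTokenizer.from_pretrained(name)
d_model = ElectraForPreTraining.from_pretrained(name)

inputs = tokenizer('the quick brown fox fake over the lazy dog', return_tensors='pt')
with torch.no_grad():
    logits = d_model(**inputs).logits[0]
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
# a positive logit means the discriminator suspects the token was replaced
print([t for t, s in zip(tokens, logits) if s > 0])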
Example #14
 def __init__(self,
              bert_model_dir=config.bert_model_dir,
              bert_config_path=config.bert_config_path,
              bert_model_path=config.bert_model_path):
     super(BertCorrector, self).__init__()
     self.name = 'bert_corrector'
     self.mask = '[MASK]'
     t1 = time.time()
     self.model = pipeline('fill-mask',
                           model=bert_model_path,
                           config=bert_config_path,
                           tokenizer=bert_model_dir)
     logger.debug('Loaded bert model: %s, spend: %.3f s.' %
                  (bert_model_dir, time.time() - t1))
Example #15
 def __init__(self, bert_model_dir=config.macbert_model_dir):
     super(MacBertCorrector, self).__init__()
     self.name = 'macbert_corrector'
     t1 = time.time()
     bert_model = BertForMaskedLM.from_pretrained(bert_model_dir)
     tokenizer = BertTokenizer.from_pretrained(bert_model_dir)
     self.model = CorrectionPipeline(
         task='correction',
         model=bert_model,
         tokenizer=tokenizer,
         device=0,  # gpu device id
     )
     if self.model:
         self.mask = self.model.tokenizer.mask_token
         logger.debug('Loaded bert model: %s, spend: %.3f s.' %
                      (bert_model_dir, time.time() - t1))
Example #16
 def __init__(self, model_dir='ernie-1.0', topN=5):
     super(ErnieCorrector, self).__init__()
     self.name = 'ernie_corrector'
     t1 = time.time()
     self.ernie_tokenizer = ErnieTokenizer.from_pretrained(model_dir)
     self.rev_dict = {v: k for k, v in self.ernie_tokenizer.vocab.items()}
     self.rev_dict[self.ernie_tokenizer.pad_id] = ''  # replace [PAD]
     self.rev_dict[self.ernie_tokenizer.sep_id] = ''  # replace [SEP]
     self.rev_dict[self.ernie_tokenizer.unk_id] = ''  # replace [UNK]
     self.cloze = ErnieCloze.from_pretrained(model_dir)
     self.cloze.eval()
     logger.debug('Loaded ernie model: %s, spend: %.3f s.' % (model_dir, time.time() - t1))
     self.mask_id = self.ernie_tokenizer.mask_id  # 3
     self.mask_token = self.rev_dict[self.mask_id]  # "[MASK]"
     logger.debug('ernie mask_id :{}, mask_token: {}'.format(self.mask_id, self.mask_token))
     self.topN = topN
Example #17
    def _initialize_detector(self):
        t1 = time.time()
        try:
            import kenlm
        except ImportError:
            raise ImportError(
                'pycorrector dependencies are not fully installed; '
                'they are required for the statistical language model. '
                'Please run "pip install kenlm" to install it. '
                'On Windows, please install kenlm under Cygwin.')
        if not os.path.exists(self.language_model_path):
            filename = self.pre_trained_language_models.get(
                self.language_model_path, 'zh_giga.no_cna_cmn.prune01244.klm')
            url = self.pre_trained_language_models.get(filename)
            get_file(filename,
                     url,
                     extract=True,
                     cache_dir='~',
                     cache_subdir=config.USER_DATA_DIR,
                     verbose=1)
        self.lm = kenlm.Model(self.language_model_path)
        t2 = time.time()
        logger.debug('Loaded language model: %s, spend: %.3f s.' %
                     (self.language_model_path, t2 - t1))

        # word-frequency dict
        self.word_freq = self.load_word_freq_dict(self.word_freq_path)
        # custom confusion set
        self.custom_confusion = self._get_custom_confusion_dict(
            self.custom_confusion_path)
        # custom segmentation dictionary
        self.custom_word_freq = self.load_word_freq_dict(
            self.custom_word_freq_path)
        self.person_names = self.load_word_freq_dict(self.person_name_path)
        self.place_names = self.load_word_freq_dict(self.place_name_path)
        self.stopwords = self.load_word_freq_dict(self.stopwords_path)
        # merge the segmentation dictionary with the custom dictionaries
        self.custom_word_freq.update(self.person_names)
        self.custom_word_freq.update(self.place_names)
        self.custom_word_freq.update(self.stopwords)
        self.word_freq.update(self.custom_word_freq)
        self.tokenizer = Tokenizer(dict_path=self.word_freq_path,
                                   custom_word_freq_dict=self.custom_word_freq,
                                   custom_confusion_dict=self.custom_confusion)
        t3 = time.time()
        logger.debug('Loaded dict file, spend: %.3f s.' % (t3 - t2))
        self.initialized_detector = True
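
Once loaded, the kenlm model assigns log10 probabilities to whitespace-separated token sequences, which is what the detector's scoring builds on; a minimal sketch (model path as in the snippet above):

import kenlm

lm = kenlm.Model('zh_giga.no_cna_cmn.prune01244.klm')
# a higher (less negative) log10 score means the sequence looks more natural
good = lm.score('少 先 队 员 应 该 为 老 人 让 座', bos=True, eos=True)
bad = lm.score('少 先 队 员 因 该 为 老 人 让 座', bos=True, eos=True)
print(good > bad)  # the correct sentence should score higher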
Example #18
    def __init__(self,
                 arch,
                 model_dir,
                 src_vocab_path=None,
                 trg_vocab_path=None,
                 embed_size=50,
                 hidden_size=50,
                 dropout=0.5,
                 max_length=128):
        logger.debug("device: {}".format(device))
        if arch in ['seq2seq', 'convseq2seq']:
            self.src_2_ids = load_word_dict(src_vocab_path)
            self.trg_2_ids = load_word_dict(trg_vocab_path)
            self.id_2_trgs = {v: k for k, v in self.trg_2_ids.items()}
            if arch == 'seq2seq':
                logger.debug('use seq2seq model.')
                self.model = Seq2Seq(encoder_vocab_size=len(self.src_2_ids),
                                     decoder_vocab_size=len(self.trg_2_ids),
                                     embed_size=embed_size,
                                     enc_hidden_size=hidden_size,
                                     dec_hidden_size=hidden_size,
                                     dropout=dropout).to(device)
                model_path = os.path.join(model_dir, 'seq2seq.pth')
                self.model.load_state_dict(torch.load(model_path))
                self.model.eval()
            else:
                logger.debug('use convseq2seq model.')
                trg_pad_idx = self.trg_2_ids[PAD_TOKEN]
                self.model = ConvSeq2Seq(
                    encoder_vocab_size=len(self.src_2_ids),
                    decoder_vocab_size=len(self.trg_2_ids),
                    embed_size=embed_size,
                    enc_hidden_size=hidden_size,
                    dec_hidden_size=hidden_size,
                    dropout=dropout,
                    trg_pad_idx=trg_pad_idx,
                    device=device,
                    max_length=max_length).to(device)
                model_path = os.path.join(model_dir, 'convseq2seq.pth')
                self.model.load_state_dict(torch.load(model_path))
                self.model.eval()
        elif arch == 'bertseq2seq':
            # Bert Seq2seq model
            logger.debug('use bert seq2seq model.')
            use_cuda = torch.cuda.is_available()

            # encoder_type=None, encoder_name=None, decoder_name=None
            self.model = Seq2SeqModel("bert",
                                      "{}/encoder".format(model_dir),
                                      "{}/decoder".format(model_dir),
                                      use_cuda=use_cuda)
        else:
            logger.error('error arch: {}'.format(arch))
            raise ValueError(
                "Invalid model arch: must be one of seq2seq, convseq2seq, bertseq2seq.")
        self.arch = arch
        self.max_length = max_length
Example #19
    def __init__(self, d_model_dir=os.path.join(pwd_path,
                                                "../data/electra_models/chinese_electra_base_discriminator_pytorch/"),
                 g_model_dir=os.path.join(pwd_path,
                                          "../data/electra_models/chinese_electra_base_generator_pytorch/"),
                 ):
        super(ElectraCorrector, self).__init__()
        self.name = 'electra_corrector'
        t1 = time.time()
        self.g_model = pipeline("fill-mask",
                                model=g_model_dir,
                                tokenizer=g_model_dir
                                )
        self.d_model = ElectraForPreTraining.from_pretrained(d_model_dir)

        if self.g_model:
            self.mask = self.g_model.tokenizer.mask_token
            logger.debug('Loaded electra model: %s, spend: %.3f s.' % (g_model_dir, time.time() - t1))
Example #20
    def __init__(self,
                 vocab_path='',
                 model_path='',
                 src_seq_lens=128,
                 trg_seq_lens=128,
                 beam_size=5,
                 batch_size=1,
                 gpu_id=0):
        use_gpu = False
        if gpu_id > -1:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)
            if torch.cuda.is_available():
                device = torch.device('cuda')
                use_gpu = True
            else:
                device = torch.device('cpu')
        else:
            device = torch.device('cpu')
        logger.debug('device: %s' % device)
        # load vocab
        self.vocab2id = load_word_dict(vocab_path)
        self.id2vocab = {v: k for k, v in self.vocab2id.items()}
        logger.debug('Loaded vocabulary file:%s, size: %s' %
                     (vocab_path, len(self.vocab2id)))

        # load model
        start_time = time.time()
        self.model = self._create_model(self.vocab2id, device)
        if use_gpu:
            self.model.load_state_dict(torch.load(model_path))
        else:
            # load all tensors onto the CPU
            self.model.load_state_dict(
                torch.load(model_path,
                           map_location=lambda storage, loc: storage))
        logger.info("Loaded model:%s, spend:%s s" %
                    (model_path, time.time() - start_time))

        self.model.eval()
        self.src_seq_lens = src_seq_lens
        self.trg_seq_lens = trg_seq_lens
        self.beam_size = beam_size
        self.batch_size = batch_size
        self.device = device
Example #21
    def __init__(self, model_dir, vocab_path):
        super(Inference, self).__init__()
        self.name = 'bert_corrector'
        t1 = time.time()
        # device
        logger.debug("device: {}".format(device))
        model, config_dict = self._read_model(model_dir)
        # norm weight
        model.norm_embedding_weight(model.criterion.W)
        self.model = model
        self.model.eval()

        self.unk_token, self.sos_token, self.eos_token, self.pad_token, self.itos, self.stoi = self._get_config_data(
            config_dict, vocab_path)
        self.model_dir = model_dir
        self.vocab_path = vocab_path
        self.mask = "[]"
        logger.debug('Loaded deep context model: %s, spend: %.3f s.' %
                     (model_dir, time.time() - t1))
Example #22
    def initialize_detector(self):
        t1 = time.time()
        self.lm = kenlm.Model(self.language_model_path)
        t2 = time.time()
        logger.debug('Loaded language model: %s, spend: %s s' %
                     (self.language_model_path, str(t2 - t1)))
        # word-frequency dict
        self.word_freq = self.load_word_freq_dict(self.word_freq_path)
        t3 = time.time()
        logger.debug('Loaded word freq file: %s, size: %d, spend: %s s' %
                     (self.word_freq_path, len(self.word_freq), str(t3 - t2)))
        # custom confusion set
        self.custom_confusion = self._get_custom_confusion_dict(
            self.custom_confusion_path)
        t4 = time.time()
        logger.debug('Loaded confusion file: %s, size: %d, spend: %s s' %
                     (self.custom_confusion_path, len(
                         self.custom_confusion), str(t4 - t3)))
        # custom segmentation dictionary
        self.custom_word_freq = self.load_word_freq_dict(
            self.custom_word_freq_path)
        self.person_names = self.load_word_freq_dict(self.person_name_path)
        self.place_names = self.load_word_freq_dict(self.place_name_path)
        self.stopwords = self.load_word_freq_dict(self.stopwords_path)
        # merge the segmentation dictionary with the custom dictionaries
        self.custom_word_freq.update(self.person_names)
        self.custom_word_freq.update(self.place_names)
        self.custom_word_freq.update(self.stopwords)

        self.word_freq.update(self.custom_word_freq)
        t5 = time.time()
        logger.debug('Loaded custom word file: %s, size: %d, spend: %s s' %
                     (self.custom_word_freq_path,
                      len(self.custom_word_freq), str(t5 - t4)))
        logger.debug('Loaded all word freq file done, size: %d' %
                     len(self.word_freq))
        self.tokenizer = Tokenizer(dict_path=self.word_freq_path,
                                   custom_word_freq_dict=self.custom_word_freq,
                                   custom_confusion_dict=self.custom_confusion)
        t6 = time.time()
        logger.info('Loaded dict ok, spend: %s s' % str(t6 - t1))
        self.initialized_detector = True
Example #23
    def predict_mask_token(self, sentence, error_begin_idx, error_end_idx):
        corrected_item = sentence[error_begin_idx:error_end_idx]
        eval_features = self._convert_sentence_to_correct_features(
            sentence=sentence,
            error_begin_idx=error_begin_idx,
            error_end_idx=error_end_idx)

        for f in eval_features:
            input_ids = torch.tensor([f.input_ids])
            segment_ids = torch.tensor([f.segment_ids])
            outputs = self.model(input_ids, segment_ids)
            predictions = outputs[0]
            # pick the most likely token at each masked position
            masked_ids = f.mask_ids
            if masked_ids:
                for idx, i in enumerate(masked_ids):
                    predicted_index = torch.argmax(predictions[0, i]).item()
                    predicted_token = self.bert_tokenizer.convert_ids_to_tokens(
                        [predicted_index])[0]
                    logger.debug('original text is: %s' % f.input_tokens)
                    logger.debug('Mask predict is: %s' % predicted_token)
                    corrected_item = predicted_token
        return corrected_item
Example #24
 def set_custom_confusion_dict(self, path):
     self.check_detector_initialized()
     self.custom_confusion = self._get_custom_confusion_dict(path)
     logger.debug('Loaded confusion path: %s, size: %d' %
                  (path, len(self.custom_confusion)))
Example #25
 def set_language_model_path(self, path):
     self.check_detector_initialized()
     import kenlm
     self.lm = kenlm.Model(path)
     logger.debug('Loaded language model: %s' % path)
Example #26
        logger.info("JAX version {}, Flax: available".format(jax.__version__))
        logger.info("Flax available: {}".format(flax))
        _flax_available = True
    else:
        _flax_available = False
except ImportError:
    _flax_available = False  # pylint: disable=invalid-name

try:
    import datasets  # noqa: F401

    # Check we're not importing a "datasets" directory somewhere
    _datasets_available = hasattr(datasets, "__version__") and hasattr(datasets, "load_dataset")
    if _datasets_available:
        logger.debug(f"Successfully imported datasets version {datasets.__version__}")
    else:
        logger.debug("Imported a datasets object but this doesn't seem to be the 🤗 datasets library.")

except ImportError:
    _datasets_available = False

try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
    )
Example #27
 def _init(self):
     self.WORDS = Counter(words(open(self.path).read()))
     logger.debug("load en spell data: %s, size: %d" %
                  (self.path, len(self.WORDS)))
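
This Counter is the word model of Norvig's classic spell corrector; a condensed sketch of the surrounding pieces under that assumption (corpus path hypothetical):

import re
from collections import Counter

def words(text):
    return re.findall(r'[a-z]+', text.lower())

WORDS = Counter(words(open('big.txt').read()))  # hypothetical corpus file

def edits1(word):
    # every string one delete, transpose, replace, or insert away from word
    letters = 'abcdefghijklmnopqrstuvwxyz'
    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
    deletes = [L + R[1:] for L, R in splits if R]
    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
    replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
    inserts = [L + c + R for L, R in splits for c in letters]
    return set(deletes + transposes + replaces + inserts)

def correction(word):
    # prefer a known word, then a known one-edit variant, then the word itself
    candidates = ({word} & WORDS.keys()) or (edits1(word) & WORDS.keys()) or {word}
    return max(candidates, key=lambda w: WORDS[w])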
Example #28
def main():
    cfg = args_parse()

    # if the training data file does not exist, preprocess it first
    if not os.path.exists(cfg.DATASETS.TRAIN):
        logger.debug('preprocess data')
        preprocess.main()
    logger.info(f'load model, model arch: {cfg.MODEL.NAME}')
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    collator = DataCollator(tokenizer=tokenizer)
    # load the data
    train_loader, valid_loader, test_loader = make_loaders(
        collator,
        train_path=cfg.DATASETS.TRAIN,
        valid_path=cfg.DATASETS.VALID,
        test_path=cfg.DATASETS.TEST,
        batch_size=cfg.SOLVER.BATCH_SIZE,
        num_workers=4)
    if cfg.MODEL.NAME == 'softmaskedbert4csc':
        model = SoftMaskedBert4Csc(cfg, tokenizer)
    elif cfg.MODEL.NAME == 'macbert4csc':
        model = MacBert4Csc(cfg, tokenizer)
    else:
        raise ValueError("model not found.")
    # load a previously saved checkpoint and continue training from it;
    # load_from_checkpoint returns a new model instance, so keep the result
    if cfg.MODEL.WEIGHTS and os.path.exists(cfg.MODEL.WEIGHTS):
        model = model.load_from_checkpoint(
            checkpoint_path=cfg.MODEL.WEIGHTS,
            cfg=cfg,
            map_location=device,
            tokenizer=tokenizer)
    # configure checkpoint saving
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    ckpt_callback = ModelCheckpoint(monitor='val_loss',
                                    dirpath=cfg.OUTPUT_DIR,
                                    filename='{epoch:02d}-{val_loss:.2f}',
                                    save_top_k=1,
                                    mode='min')
    # train the model
    logger.info('train model ...')
    trainer = pl.Trainer(
        max_epochs=cfg.SOLVER.MAX_EPOCHS,
        gpus=None if device == torch.device('cpu') else cfg.MODEL.GPU_IDS,
        accumulate_grad_batches=cfg.SOLVER.ACCUMULATE_GRAD_BATCHES,
        callbacks=[ckpt_callback])
    # run training only when train_loader has data
    torch.autograd.set_detect_anomaly(True)
    if 'train' in cfg.MODE and train_loader and len(train_loader) > 0:
        if valid_loader and len(valid_loader) > 0:
            trainer.fit(model, train_loader, valid_loader)
        else:
            trainer.fit(model, train_loader)
        logger.info('train model done.')
    # export the model so it can be loaded by transformers
    if ckpt_callback and len(ckpt_callback.best_model_path) > 0:
        ckpt_path = ckpt_callback.best_model_path
    elif cfg.MODEL.WEIGHTS and os.path.exists(cfg.MODEL.WEIGHTS):
        ckpt_path = cfg.MODEL.WEIGHTS
    else:
        ckpt_path = ''
    logger.info(f'ckpt_path: {ckpt_path}')
    if ckpt_path and os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path)['state_dict'])
        # first save the tokenizer and the original transformers bert model
        tokenizer.save_pretrained(cfg.OUTPUT_DIR)
        bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
        bert.save_pretrained(cfg.OUTPUT_DIR)
        state_dict = torch.load(ckpt_path)['state_dict']
        new_state_dict = OrderedDict()
        if cfg.MODEL.NAME in ['macbert4csc']:
            for k, v in state_dict.items():
                if k.startswith('bert.'):
                    new_state_dict[k[5:]] = v
        else:
            new_state_dict = state_dict
        # then save the fine-tuned weights, replacing the original pytorch_model.bin
        torch.save(new_state_dict,
                   os.path.join(cfg.OUTPUT_DIR, 'pytorch_model.bin'))
    # testing follows the same pattern as training
    if 'test' in cfg.MODE and test_loader and len(test_loader) > 0:
        trainer.test(model, test_loader)
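
After this export, OUTPUT_DIR is a plain transformers checkpoint (tokenizer files plus pytorch_model.bin), so the trained corrector can be reloaded without pytorch-lightning; a minimal sketch, assuming the macbert4csc output directory from the config above:

import torch
from transformers import BertTokenizer, BertForMaskedLM

output_dir = 'output/macbert4csc'  # cfg.OUTPUT_DIR; path illustrative
tokenizer = BertTokenizer.from_pretrained(output_dir)
model = BertForMaskedLM.from_pretrained(output_dir)
model.eval()

inputs = tokenizer('今天新情很好', return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits
# macbert4csc predicts the corrected character at every position
ids = logits[0].argmax(dim=-1)
print(tokenizer.decode(ids, skip_special_tokens=True).replace(' ', ''))  # expected: 今天心情很好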