Example #1
def mixed_segmentation(in_str, rm_punc=False):
    nltk_path = os.path.join(THIRD_PARTY_HOME, "nltk_data")
    tokenizers_path = os.path.join(nltk_path, "tokenizers")
    punkt_path = os.path.join(tokenizers_path, "punkt")
    if not os.path.exists(punkt_path):
        default_downloader.download_file_and_uncompress(
            url=_PUNKT_URL, save_path=tokenizers_path, print_progress=True)
    nltk.data.path.append(nltk_path)

    in_str = str(in_str).lower().strip()
    segs_out = []
    temp_str = ""
    sp_char = [
        '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
        '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
        ')', '-', '~', '『', '』'
    ]
    # Walk the string: Chinese characters and listed separators are emitted as
    # single tokens; everything else is buffered and split with nltk.word_tokenize.
    for char in in_str:
        if rm_punc and char in sp_char:
            continue
        if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
            if temp_str != "":
                ss = nltk.word_tokenize(temp_str)
                segs_out.extend(ss)
                temp_str = ""
            segs_out.append(char)
        else:
            temp_str += char

    # flush any trailing non-Chinese segment
    if temp_str != "":
        ss = nltk.word_tokenize(temp_str)
        segs_out.extend(ss)

    return segs_out
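A minimal usage sketch for the function above (illustrative only; it assumes the surrounding module's imports and globals such as os, re, nltk, THIRD_PARTY_HOME, default_downloader and _PUNKT_URL are already in scope):

# Chinese characters come out one per token; Latin runs are buffered and split
# with nltk.word_tokenize once a Chinese character or separator is reached.
tokens = mixed_segmentation("北京 is the capital of China。", rm_punc=True)
print(tokens)  # e.g. ['北', '京', 'is', 'the', 'capital', 'of', 'china']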
Example #2
File: glue.py Project: VVJY/PaddleHub
    def __init__(self, sub_dataset='SST-2'):
        # sub_dataset: CoLA, MNLI (or MNLI_m / MNLI_mm), MRPC, QNLI, QQP, RTE, SST-2, STS-B
        if sub_dataset not in [
                'CoLA', 'MNLI', 'MNLI_m', 'MNLI_mm', 'MRPC', 'QNLI', 'QQP',
                'RTE', 'SST-2', 'STS-B'
        ]:
            raise Exception(
                sub_dataset +
                " is not in the GLUE benchmark. Please check the dataset name.")
        self.mismatch = False
        if sub_dataset == 'MNLI_mm':
            sub_dataset = 'MNLI'
            self.mismatch = True
        elif sub_dataset == 'MNLI_m':
            sub_dataset = 'MNLI'
        self.sub_dataset = sub_dataset
        self.dataset_dir = os.path.join(DATA_HOME, "glue_data")

        if not os.path.exists(self.dataset_dir):
            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
        else:
            logger.info("Dataset {} already cached.".format(self.dataset_dir))

        self._load_train_examples()
        self._load_dev_examples()
        self._load_test_examples()
        self._load_predict_examples()
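A brief sketch of how the MNLI aliases above resolve (the class name GLUE and the paddlehub import are assumptions inferred from glue.py, not shown in this snippet):

import paddlehub as hub

# Both aliases load the MNLI files; only the mismatch flag differs.
mnli_matched = hub.dataset.GLUE(sub_dataset='MNLI_m')      # sub_dataset == 'MNLI', mismatch == False
mnli_mismatched = hub.dataset.GLUE(sub_dataset='MNLI_mm')  # sub_dataset == 'MNLI', mismatch == True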
Example #3
 def _init_with_url(self, url):
     utils.check_url(url)
     result, tips, module_dir = default_downloader.download_file_and_uncompress(
         url, save_path=".")
     if not result:
         logger.error(tips)
         exit(1)
     self._init_with_module_file(module_dir)
Example #4
File: squad.py Project: jayhenry/PaddleHub
 def __init__(self, version_2_with_negative=False):
     self.dataset_dir = os.path.join(DATA_HOME, "squad_data")
     if not os.path.exists(self.dataset_dir):
         ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
             url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
     else:
         logger.info("Dataset {} already cached.".format(self.dataset_dir))
     self.version_2_with_negative = version_2_with_negative
     self._load_train_examples(version_2_with_negative, if_has_answer=True)
     self._load_dev_examples(version_2_with_negative, if_has_answer=True)
Example #5
 def _download_dataset(self, dataset_path, url):
     if not os.path.exists(dataset_path):
         result, tips, dataset_path = default_downloader.download_file_and_uncompress(
             url=url,
             save_path=hub.common.dir.DATA_HOME,
             print_progress=True,
             replace=True)
         if not result:
             print(tips)
             exit()
     return dataset_path
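A hedged sketch of how a helper like this is typically wired into a dataset constructor (the class name, directory name, and URL below are placeholders for illustration, not taken from the source):

import os
import paddlehub as hub

class ExampleDataset(object):
    # Assumes ExampleDataset also defines the _download_dataset helper shown above.
    def __init__(self):
        dataset_path = os.path.join(hub.common.dir.DATA_HOME, "example_data")
        self.dataset_dir = self._download_dataset(
            dataset_path=dataset_path,
            url="https://example.com/example_data.tar.gz")  # placeholder URL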
Example #6
    def __init__(self):
        self.dataset_dir = os.path.join(DATA_HOME, "msra_ner")
        if not os.path.exists(self.dataset_dir):
            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
        else:
            logger.info("Dataset {} already cached.".format(self.dataset_dir))

        self._load_train_examples()
        self._load_test_examples()
        self._load_dev_examples()
Example #7
 def _download_dataset(self, dataset_path, url):
     if not os.path.exists(dataset_path):
         result, tips, dataset_path = default_downloader.download_file_and_uncompress(
             url=url,
             save_path=hub.common.dir.DATA_HOME,
             print_progress=True,
             replace=True)
         if not result:
             raise Exception(tips)
     else:
         logger.info("Dataset {} already cached.".format(dataset_path))
     return dataset_path
Example #8
    def __init__(self, language='zh'):
        if language not in [
                "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw",
                "th", "tr", "ur", "vi", "zh"
        ]:
            raise Exception(language +
                            " is not in XNLI. Please confirm the language.")
        self.language = language
        self.dataset_dir = os.path.join(DATA_HOME, "XNLI-lan")

        if not os.path.exists(self.dataset_dir):
            ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
                url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
        else:
            logger.info("Dataset {} already cached.".format(self.dataset_dir))

        self._load_train_examples()
        self._load_test_examples()
        self._load_dev_examples()
Example #9
# lack_dependency is initialized here so the snippet stands alone; the other names
# (importlib, os, nltk, THIRD_PARTY_HOME, default_downloader) come from the surrounding module.
lack_dependency = []
for dependency in ["ruamel", "parakeet", "soundfile", "librosa"]:
    if not importlib.util.find_spec(dependency):
        lack_dependency.append(dependency)

# Download the NLTK data from PaddleHub's mirror to speed things up; 'import parakeet' below relies on these resources.
_PUNKT_URL = "https://paddlehub.bj.bcebos.com/paddlehub-thirdparty/punkt.tar.gz"
_CMUDICT_URL = "https://paddlehub.bj.bcebos.com/paddlehub-thirdparty/cmudict.tar.gz"
nltk_path = os.path.join(THIRD_PARTY_HOME, "nltk_data")
tokenizers_path = os.path.join(nltk_path, "tokenizers")
corpora_path = os.path.join(nltk_path, "corpora")
punkt_path = os.path.join(tokenizers_path, "punkt")
cmudict_path = os.path.join(corpora_path, "cmudict")

if not os.path.exists(punkt_path):
    default_downloader.download_file_and_uncompress(url=_PUNKT_URL,
                                                    save_path=tokenizers_path,
                                                    print_progress=True)
if not os.path.exists(cmudict_path):
    default_downloader.download_file_and_uncompress(url=_CMUDICT_URL,
                                                    save_path=corpora_path,
                                                    print_progress=True)
nltk.data.path.append(nltk_path)

if not lack_dependency:
    import soundfile as sf
    import librosa
    from ruamel import yaml
    from parakeet.models.fastspeech.fastspeech import FastSpeech as FastSpeechModel
    from parakeet.g2p.en import text_to_sequence
    from parakeet.models.transformer_tts.utils import *
    from parakeet.utils import io
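Once nltk.data.path includes the PaddleHub-managed directory, the mirrored resources resolve without a separate nltk.download() call; a minimal check (illustrative):

# Both lookups raise LookupError if the mirrored archives were not unpacked correctly.
nltk.data.find("tokenizers/punkt")
nltk.data.find("corpora/cmudict")
print(nltk.word_tokenize("hello paddle speech"))  # ['hello', 'paddle', 'speech']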