Example #1
    def download(key):
        """Resolve *key* to a local Wav2vec2.0 model file, fetching it if absent.

        Args:
            key: A name registered in ``_W2V_PRETRAIN_MODELS``, or a direct
                http(s) URL to a model file.

        Returns:
            The local path of the model file, or None when *key* is neither a
            known model name nor a URL.
        """
        if key in _W2V_PRETRAIN_MODELS:
            url = _W2V_PRETRAIN_MODELS[key]
        elif key.startswith(("http://", "https://")):
            url = key
        else:
            return None
        # Cache the file next to this module, keyed by its basename.
        local_path = os.path.join(os.path.dirname(__file__),
                                  os.path.basename(url))
        if not os.path.exists(local_path):
            logging.info(f"Downloading Wav2vec2.0: {key}")
            download_with_tqdm(url, local_path)
        return local_path
Example #2
 def download(key):
     """Resolve *key* to a local, extracted Google BERT checkpoint directory.

     Downloads the zip archive on first use, unpacks it next to this module,
     and removes the archive afterwards.

     Args:
         key: A name registered in ``_BERT_PRETRAIN_MODELS``, or a direct
             http(s) URL to a zipped checkpoint.

     Returns:
         Path to the checkpoint directory (the nested directory when the
         archive contains one), or None when *key* is unrecognized.
     """
     if key in _BERT_PRETRAIN_MODELS:
         url = _BERT_PRETRAIN_MODELS[key]
     elif key.startswith(("http://", "https://")):
         url = key
     else:
         return None
     this_dir = os.path.dirname(__file__)
     archive_name = os.path.basename(url)
     # Directory name is the archive basename minus its extension(s).
     bert_name = archive_name.split(".")[0]
     extract_path = os.path.join(this_dir, bert_name)
     if not os.path.exists(extract_path):
         logging.info(f"Downloading google bert: {key}")
         tarball = os.path.join(this_dir, archive_name)
         download_with_tqdm(url, tarball)
         tf.io.gfile.makedirs(extract_path)
         with zipfile.ZipFile(tarball) as zip_ref:
             zip_ref.extractall(extract_path)
         os.remove(tarball)
     # Some archives nest their content one directory deeper; prefer that.
     nested = os.path.join(extract_path, bert_name)
     if os.path.isdir(nested):
         return nested
     return extract_path
Example #3
 def download(key):
     """Fetch the full set of OpenAI GPT-2 checkpoint files named by *key*.

     Each required file is downloaded individually into a local
     ``GPT2_<key>`` directory; files already present are skipped.

     Args:
         key: A name registered in ``_GPT2_PRETRAIN_MODELS``, or a base
             http(s) URL hosting the checkpoint files.

     Returns:
         Path to the local model directory, or None when *key* is neither a
         known model name nor a URL.
     """
     if key in _GPT2_PRETRAIN_MODELS:
         url = _GPT2_PRETRAIN_MODELS[key]
     elif key.startswith(("http://", "https://")):
         url = key
     else:
         return None
     logging.info(f"Downloading openai gpt2: {key}")
     model_dir = os.path.join(os.path.dirname(__file__), f"GPT2_{key}")
     if not os.path.exists(model_dir):
         os.makedirs(model_dir)
     required_files = (
         'checkpoint', 'encoder.json', 'hparams.json',
         'model.ckpt.data-00000-of-00001', 'model.ckpt.index',
         'model.ckpt.meta', 'vocab.bpe',
     )
     for fname in required_files:
         file_url = f"{url}/{fname}"
         target = os.path.join(model_dir, fname)
         if os.path.exists(target):
             continue
         logging.info(f"Downloading {file_url}")
         download_with_tqdm(file_url, target)
     return model_dir
Example #4
 def _lazy_init(self):
     """Load the SentencePiece model referenced by ``self._codes``.

     Remote models (hdfs:// or http(s)) are first copied/downloaded to a
     temporary local file, loaded into ``self._sp``, and the temporary copy
     is removed afterwards. Sets ``self._built`` on success.

     Raises:
         AssertionError: If the SentencePiece loader rejects the model.
     """
     codes = self._codes
     # Non-None iff we materialized a temporary local copy to delete later.
     temp_model = None
     if codes.startswith("hdfs://"):
         temp_model = os.path.join(os.path.dirname(__file__),
                                   "spm{}.model".format(time.time()))
         logging.info("Copying spm model: {} to local: {}".format(
             codes, temp_model))
         tf.io.gfile.copy(codes, temp_model, overwrite=True)
         codes = temp_model
     elif codes.startswith("http"):
         temp_model = os.path.join(os.path.dirname(__file__),
                                   "spm{}.model".format(time.time()))
         logging.info(
             "Downloading spm model to local: {}".format(temp_model))
         download_with_tqdm(codes, temp_model)
         codes = temp_model
     status = self._sp.Load(codes)
     assert status, "Fail to load spm model: {}".format(codes)
     if temp_model is not None:
         tf.io.gfile.remove(temp_model)
     self._built = True
Example #5
def _fetch_and_extract(resource, output_dir):
    """Ensure a resource's SRC/TRG files exist under *output_dir*.

    Downloads the resource tarball if needed and extracts only the members
    whose names end with the resource's SRC/TRG entries.

    Args:
        resource: Mapping with "SRC", "TRG", "TAR" and "URL" keys.
        output_dir: Directory to download into and extract under.

    Returns:
        Tuple ``(src_path, trg_path)`` of the two local file paths.
    """
    src = os.path.join(output_dir, resource["SRC"])
    trg = os.path.join(output_dir, resource["TRG"])
    if os.path.exists(src) and os.path.exists(trg):
        return src, trg
    tar_filename = os.path.join(output_dir, resource["TAR"])
    if not os.path.exists(tar_filename):
        download_with_tqdm(resource["URL"], tar_filename)
    # NOTE(review): members are extracted without path sanitization
    # (tarfile path traversal, CVE-2007-4559); this assumes the upstream
    # archives are trusted. Consider tarfile's extraction filters.
    with tarfile.open(tar_filename, "r:*") as tar:
        for tarinfo in tar:
            if (tarinfo.name.endswith(resource["SRC"])
                    or tarinfo.name.endswith(resource["TRG"])):
                tar.extract(tarinfo, output_dir)
    return src, trg


def download_to(output_dir):
    """Download and normalize the WMT train/dev/test corpora into *output_dir*.

    Produces ``train.en.txt``/``train.de.txt`` (whitespace-normalized),
    ``newstest2013.en.txt``/``newstest2013.de.txt`` (stripped dev lines) and
    ``newstest2014.en.txt``/``newstest2014.de.txt`` (``<seg id`` lines only,
    passed through ``_wrapper_xml``).

    Args:
        output_dir: Directory that receives downloads and output files.

    Returns:
        Tuple ``(training_srcs, training_trgs)`` of the raw per-resource
        training file paths.
    """
    # Training data: fetch every resource, then merge into two text files.
    training_srcs = []
    training_trgs = []
    for traindata in _TRAIN_DATA_RESOURCES:
        src, trg = _fetch_and_extract(traindata, output_dir)
        training_srcs.append(src)
        training_trgs.append(trg)

    with tf.io.gfile.GFile(os.path.join(output_dir, "train.en.txt"),
                           "w") as fw_src:
        with tf.io.gfile.GFile(os.path.join(output_dir, "train.de.txt"),
                               "w") as fw_trg:
            for src, trg in zip(training_srcs, training_trgs):
                with tf.io.gfile.GFile(src, "r") as f_src, tf.io.gfile.GFile(
                        trg, "r") as f_trg:
                    for s, t in zip(f_src, f_trg):
                        # Collapse internal whitespace runs to single spaces.
                        fw_src.write(" ".join(s.strip().split()) + "\n")
                        fw_trg.write(" ".join(t.strip().split()) + "\n")

    # Dev data (newstest2013): plain strip, no filtering.
    dev_src, dev_trg = _fetch_and_extract(_DEV_DATA_RESOURCE, output_dir)
    with open(os.path.join(output_dir, "newstest2013.en.txt"), "w") as fw_src:
        with open(os.path.join(output_dir, "newstest2013.de.txt"),
                  "w") as fw_trg:
            with tf.io.gfile.GFile(dev_src) as f_src, tf.io.gfile.GFile(
                    dev_trg) as f_trg:
                for s, t in zip(f_src, f_trg):
                    fw_src.write(s.strip() + "\n")
                    fw_trg.write(t.strip() + "\n")

    # Test data (newstest2014): SGM files — keep only <seg id...> lines and
    # unwrap the XML markup.
    test_src, test_trg = _fetch_and_extract(_TEST_DATA_RESOURCE, output_dir)
    with open(os.path.join(output_dir, "newstest2014.en.txt"), "w") as fw_src:
        with open(os.path.join(output_dir, "newstest2014.de.txt"),
                  "w") as fw_trg:
            with tf.io.gfile.GFile(test_src) as f_src, tf.io.gfile.GFile(
                    test_trg) as f_trg:
                for s, t in zip(f_src, f_trg):
                    if s.startswith("<seg id"):
                        fw_src.write(_wrapper_xml(s.strip()) + "\n")
                        fw_trg.write(_wrapper_xml(t.strip()) + "\n")
    return training_srcs, training_trgs