Example #1
    def _get_data(self, mode, **kwargs):
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        src_filename, tgt_filename, src_data_hash, tgt_data_hash = self.SPLITS[
            mode]
        src_fullname = os.path.join(default_root, src_filename)
        tgt_fullname = os.path.join(default_root, tgt_filename)

        (bpe_vocab_filename, bpe_vocab_hash), (sub_vocab_filename,
                                               sub_vocab_hash) = self.VOCAB_INFO
        bpe_vocab_fullname = os.path.join(default_root, bpe_vocab_filename)
        sub_vocab_fullname = os.path.join(default_root, sub_vocab_filename)

        # Re-download the archive when any data or vocab file is missing or
        # fails its md5 check.
        if (not os.path.exists(src_fullname)
                or (src_data_hash and md5file(src_fullname) != src_data_hash)
                or not os.path.exists(tgt_fullname)
                or (tgt_data_hash and md5file(tgt_fullname) != tgt_data_hash)
                or not os.path.exists(bpe_vocab_fullname)
                or (bpe_vocab_hash
                    and md5file(bpe_vocab_fullname) != bpe_vocab_hash)
                or not os.path.exists(sub_vocab_fullname)
                or (sub_vocab_hash
                    and md5file(sub_vocab_fullname) != sub_vocab_hash)):
            get_path_from_url(self.URL, default_root, self.MD5)

        return src_fullname, tgt_fullname
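The method above relies on a handful of class attributes. A minimal sketch of the layout it expects, with hypothetical filenames, urls and hashes rather than real dataset metadata:

# Hypothetical attributes; the structure mirrors how _get_data unpacks them.
URL = "https://example.com/translation-data.tar.gz"  # placeholder archive url
MD5 = None  # archive-level md5 (may be None)

SPLITS = {
    # mode: (src_filename, tgt_filename, src_data_hash, tgt_data_hash)
    "train": ("train.src", "train.tgt", None, None),
    "dev": ("dev.src", "dev.tgt", None, None),
}

VOCAB_INFO = (
    ("vocab.bpe", None),  # (bpe_vocab_filename, bpe_vocab_hash)
    ("vocab.sub", None),  # (sub_vocab_filename, sub_vocab_hash)
)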
Example #2
def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum."""
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    filepath = os.path.join(target_dir, url.split("/")[-1])
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        print("Downloading %s ..." % url)
        os.system("wget -c " + url + " -P " + target_dir)
        print("\nMD5 Chesksum %s ..." % filepath)
        if not md5file(filepath) == md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath
Example #3
    def _get_data(self, mode, **kwargs):
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        vocab_filename, vocab_hash = self.VOCAB_INFO
        vocab_fullname = os.path.join(default_root, vocab_filename)

        if (not os.path.exists(fullname)
            ) or (data_hash and not md5file(fullname) == data_hash) or (
                not os.path.exists(vocab_fullname) or
                (vocab_hash and not md5file(vocab_fullname) == vocab_hash)):

            get_path_from_url(self.URL, default_root, self.MD5)

        return fullname
Example #4
 def _get_data(self, mode):
     """ Check and download Dataset """
     dl_paths = {}
     version = self.config.get("version", "3.0.0")
     if version not in ["1.0.0", "2.0.0", "3.0.0"]:
         raise ValueError("Unsupported version: %s" % version)
     dl_paths["version"] = version
     default_root = os.path.join(DATA_HOME, self.__class__.__name__)
     for k, v in self.cnn_dailymail.items():
         dir_path = os.path.join(default_root, k)
         if not os.path.exists(dir_path):
             get_path_from_url(v["url"], default_root, v["md5"])
         unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                                  .trainer_endpoints[:])
         if ParallelEnv().current_endpoint in unique_endpoints:
             file_num = len(os.listdir(os.path.join(dir_path, "stories")))
             if file_num != v["file_num"]:
                 logger.warning(
                     "Number of %s stories is %d != %d, decompress again." %
                     (k, file_num, v["file_num"]))
                 shutil.rmtree(os.path.join(dir_path, "stories"))
                 _decompress(
                     os.path.join(default_root, os.path.basename(v["url"])))
         dl_paths[k] = dir_path
     filename, url, data_hash = self.SPLITS[mode]
     fullname = os.path.join(default_root, filename)
     if not os.path.exists(fullname) or (data_hash and
                                         not md5file(fullname) == data_hash):
         get_path_from_url(url, default_root, data_hash)
     dl_paths[mode] = fullname
     return dl_paths
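For the loop over self.cnn_dailymail above, each entry needs "url", "md5" and "file_num" keys. A hedged sketch of that mapping; the urls, md5s and counts are illustrative placeholders:

# Hypothetical config; the keys match the v["url"], v["md5"], v["file_num"]
# accesses above, the values are placeholders.
cnn_dailymail = {
    "cnn": {
        "url": "https://example.com/cnn_stories.tgz",
        "md5": None,
        "file_num": 92579,  # expected number of files under cnn/stories
    },
    "dailymail": {
        "url": "https://example.com/dailymail_stories.tgz",
        "md5": None,
        "file_num": 219506,  # expected number of files under dailymail/stories
    },
}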
Example #5
    def _download_data(cls, mode="train", root=None):
        """Download dataset"""
        default_root = os.path.join(DATA_HOME, 'machine_translation',
                                    cls.__name__)
        src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[
            mode]

        filename_list = [
            src_filename, tgt_filename, cls.VOCAB_INFO[0], cls.VOCAB_INFO[1]
        ]
        fullname_list = []
        for filename in filename_list:
            fullname = os.path.join(
                default_root, filename) if root is None else os.path.join(
                    os.path.expanduser(root), filename)
            fullname_list.append(fullname)

        data_hash_list = [
            src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3]
        ]
        for i, fullname in enumerate(fullname_list):
            if not os.path.exists(fullname) or (
                    data_hash_list[i]
                    and not md5file(fullname) == data_hash_list[i]):
                if root is not None:  # warn only when a root was explicitly given
                    warnings.warn(
                        'md5 check failed for {}, download {} data to {}'.
                        format(filename, cls.__name__, default_root))
                path = get_path_from_url(cls.URL, default_root, cls.MD5)
                return default_root
        return root if root is not None else default_root
Example #6
File: squad.py Project: jandyu/models-1
 def _get_data(self, root, mode, **kwargs):
     default_root = os.path.join(DATA_HOME, self.__class__.__name__)
     if self.version_2_with_negative:
         filename, data_hash = self.SPLITS['2.0'][mode]
     else:
         filename, data_hash = self.SPLITS['1.1'][mode]
     fullname = os.path.join(default_root,
                             filename) if root is None else os.path.join(
                                 os.path.expanduser(root), filename)
     if not os.path.exists(fullname) or (
             data_hash and not md5file(fullname) == data_hash):
         if root is not None:  # warn only when a root was explicitly given
             warnings.warn(
                 'md5 check failed for {}, download {} data to {}'.format(
                     filename, self.__class__.__name__, default_root))
         if mode == 'train':
             if self.version_2_with_negative:
                 fullname = get_path_from_url(
                     self.TRAIN_DATA_URL_V2,
                     os.path.join(default_root, 'v2'))
             else:
                 fullname = get_path_from_url(
                     self.TRAIN_DATA_URL_V1,
                     os.path.join(default_root, 'v1'))
         elif mode == 'dev':
             if self.version_2_with_negative:
                 fullname = get_path_from_url(
                     self.DEV_DATA_URL_V2, os.path.join(default_root, 'v2'))
             else:
                 fullname = get_path_from_url(
                     self.DEV_DATA_URL_V1, os.path.join(default_root, 'v1'))
     self.full_path = fullname
Example #7
def download_file(save_dir, filename, url, md5=None, task=None):
    """
    Download the file from the url to specified directory. 
    Check md5 value when the file is exists, if the md5 value is the same as the existed file, just use 
    the older file, if not, will download the file from the url.

    Args:
        save_dir(string): The specified directory saving the file.
        filename(string): The specified filename saving the file.
        url(string): The url downling the file.
        md5(string, optional): The md5 value that checking the version downloaded. 
    """
    logger.disable()
    global DOWNLOAD_CHECK
    if not DOWNLOAD_CHECK:
        DOWNLOAD_CHECK = True
        checker = DownloaderCheck(task)
        checker.start()
        checker.join()
    fullname = os.path.join(save_dir, filename)
    if os.path.exists(fullname):
        if md5 and (not md5file(fullname) == md5):
            get_path_from_url(url, save_dir, md5)
    else:
        get_path_from_url(url, save_dir, md5)
    logger.enable()
    return fullname
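A minimal usage sketch for the helper above; the directory, filename, url, md5 and task below are placeholders, not real resources:

# Hypothetical call: downloads only if the local copy is missing or its md5
# does not match the expected value.
model_path = download_file(
    save_dir="/tmp/my_task",
    filename="model_state.pdparams",
    url="https://example.com/my_task/model_state.pdparams",
    md5="0123456789abcdef0123456789abcdef",  # placeholder md5
    task="my_task")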
Example #8
    def _get_data(self, mode, **kwargs):
        """Downloads dataset."""
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash, url, zipfile_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        if mode == 'train':
            if not os.path.exists(fullname):
                get_path_from_url(url, default_root, zipfile_hash)
            unique_endpoints = _get_unique_endpoints(
                ParallelEnv().trainer_endpoints[:])
            if ParallelEnv().current_endpoint in unique_endpoints:
                file_num = len(os.listdir(fullname))
                if file_num != len(ALL_LANGUAGES):
                    logger.warning(
                        "Number of train files is %d != %d, decompress again."
                        % (file_num, len(ALL_LANGUAGES)))
                    shutil.rmtree(fullname)
                    _decompress(
                        os.path.join(default_root, os.path.basename(url)))
        else:
            if not os.path.exists(fullname) or (
                    data_hash and not md5file(fullname) == data_hash):
                get_path_from_url(url, default_root, zipfile_hash)

        return fullname
Example #9
    def _get_data(self, mode, **kwargs):
        default_root = DATA_HOME
        filename, data_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            # Download when missing or the md5 check fails; assumes URL and MD5
            # class attributes, as in the other examples above.
            get_path_from_url(self.URL, default_root, self.MD5)
            fullname = os.path.join(default_root, filename)

        return fullname
Example #10
 def _get_data(self, mode, **kwargs):
     builder_config = self.BUILDER_CONFIGS[self.name]
     default_root = os.path.join(DATA_HOME, self.__class__.__name__)
     filename, url, data_hash = builder_config['splits'][mode]
     fullname = os.path.join(default_root, filename)
     if not os.path.exists(fullname) or (
             data_hash and not md5file(fullname) == data_hash):
         get_path_from_url(url, default_root, data_hash)
     return fullname
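The builder_config lookup above expects a nested mapping keyed by config name, whose "splits" entry yields (filename, url, data_hash) per mode. A hedged sketch with placeholder names and urls:

# Hypothetical BUILDER_CONFIGS layout matching the unpacking above.
BUILDER_CONFIGS = {
    "default": {
        "splits": {
            # mode: (filename, url, data_hash)
            "train": ("train.json", "https://example.com/train.json", None),
            "dev": ("dev.json", "https://example.com/dev.json", None),
        },
    },
}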
Example #11
    def _get_data(self, mode, **kwargs):
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            get_path_from_url(self.URL, default_root)

        return fullname
Example #12
    def _get_data(self, mode, **kwargs):
        """Downloads dataset."""
        default_root = DATA_HOME
        filename, data_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            path = get_path_from_url(self.URL, default_root, self.MD5)
            fullname = os.path.join(default_root, filename)

        return fullname
Example #13
 def _get_data(self, mode, **kwargs):
     default_root = os.path.join(DATA_HOME, self.__class__.__name__, mode)
     meta_info_list = self.SPLITS[mode]
     fullnames = []
     for meta_info in meta_info_list:
         filename, data_hash, URL = meta_info
         fullname = os.path.join(default_root, filename)
         if not os.path.exists(fullname) or (
                 data_hash and not md5file(fullname) == data_hash):
             get_path_from_url(URL, default_root)
         fullnames.append(fullname)
     return fullnames
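Here SPLITS maps each mode to a list of per-file entries rather than a single tuple, with a separate URL per file. A sketch of the expected shape, using placeholder names and urls:

# Hypothetical SPLITS layout; each entry is (filename, data_hash, URL).
SPLITS = {
    "train": [
        ("train_part1.txt", None, "https://example.com/train_part1.txt"),
        ("train_part2.txt", None, "https://example.com/train_part2.txt"),
    ],
    "dev": [
        ("dev.txt", None, "https://example.com/dev.txt"),
    ],
}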
Example #14
    def _get_data(self, mode, **kwargs):
        """Downloads dataset."""
        builder_config = self.BUILDER_CONFIGS[self.name]
        default_root = os.path.join(DATA_HOME, 'SE-ABSA16_PHNS')
        filename, data_hash, _, _ = builder_config['splits'][mode]
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            url = builder_config['url']
            md5 = builder_config['md5']
            get_path_from_url(url, DATA_HOME, md5)

        return fullname
Example #15
    def test_download_url(self):
        LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
        LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'

        catch_exp = False
        try:
            download(LABEL_URL, 'flowers', LABEL_MD5)
        except Exception as e:
            catch_exp = True

        self.assertFalse(catch_exp)

        file_path = DATA_HOME + "/flowers/imagelabels.mat"

        self.assertTrue(os.path.exists(file_path))
        self.assertEqual(md5file(file_path), LABEL_MD5)
Example #16
    def _get_data(self, root, mode, **kwargs):
        default_root = os.path.join(DATA_HOME, 'DuReader')

        filename, data_hash = self.SPLITS[mode]

        fullname = os.path.join(default_root,
                                filename) if root is None else os.path.join(
                                    os.path.expanduser(root), filename)
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            if root is not None:  # warn only when a root was explicitly given
                warnings.warn(
                    'md5 check failed for {}, download {} data to {}'.format(
                        filename, self.__class__.__name__, default_root))

            get_path_from_url(self.DATA_URL, default_root)

        self.full_path = fullname
Example #17
File: lcqmc.py Project: wbj0110/models
 def _get_data(self, root, mode, **kwargs):
     default_root = DATA_HOME
     filename, data_hash, field_indices, num_discard_samples = self.SPLITS[
         mode]
     fullname = os.path.join(default_root,
                             filename) if root is None else os.path.join(
                                 os.path.expanduser(root), filename)
     if not os.path.exists(fullname) or (
             data_hash and not md5file(fullname) == data_hash):
         if root is not None:  # warn only when a root was explicitly given
             warnings.warn(
                 'md5 check failed for {}, download {} data to {}'.format(
                     filename, self.__class__.__name__, default_root))
         path = get_path_from_url(self.URL, default_root, self.MD5)
         fullname = os.path.join(default_root, filename)
     super(LCQMC, self).__init__(fullname,
                                 field_indices=field_indices,
                                 num_discard_samples=num_discard_samples,
                                 **kwargs)
Example #18
def download_file(save_dir, filename, url, md5=None):
    """
    Download the file from the url to specified directory. 
    Check md5 value when the file is exists, if the md5 value is the same as the existed file, just use 
    the older file, if not, will download the file from the url.

    Args:
        save_dir(string): The specified directory saving the file.
        filename(string): The specified filename saving the file.
        url(string): The url downling the file.
        md5(string, optional): The md5 value that checking the version downloaded. 
    """
    fullname = os.path.join(save_dir, filename)
    if os.path.exists(fullname):
        if md5 and (not md5file(fullname) == md5):
            logger.disable()
            get_path_from_url(url, save_dir, md5)
    else:
        logger.info("Downloading {} from {}".format(filename, url))
        logger.disable()
        get_path_from_url(url, save_dir, md5)
    logger.enable()
    return fullname
Example #19
    def _check_task_files(self):
        """
        Check files required by the task.
        """
        for file_id, file_name in self.resource_files_names.items():
            path = os.path.join(self._task_path, file_name)
            url = self.resource_files_urls[self.model][file_id][0]
            md5 = self.resource_files_urls[self.model][file_id][1]

            downloaded = True
            if not os.path.exists(path):
                downloaded = False
            else:
                if not self._custom_model:
                    if os.path.exists(path):
                        # Check whether the file is updated
                        if not md5file(path) == md5:
                            downloaded = False
                            if file_id == "model_state":
                                self._param_updated = True
                    else:
                        downloaded = False
            if not downloaded:
                download_file(self._task_path, file_name, url, md5)
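_check_task_files assumes two parallel mappings on the task object: one from file id to filename, one from model name to per-file [url, md5] pairs. A hedged sketch of their shape, with placeholder names, urls and md5s:

# Hypothetical resource layout matching the lookups above; not a real task.
resource_files_names = {
    "model_state": "model_state.pdparams",
    "model_config": "model_config.json",
}
resource_files_urls = {
    "my-model": {
        "model_state": [
            "https://example.com/my-model/model_state.pdparams",
            "0123456789abcdef0123456789abcdef",  # placeholder md5
        ],
        "model_config": [
            "https://example.com/my-model/model_config.json",
            "fedcba9876543210fedcba9876543210",  # placeholder md5
        ],
    },
}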
Example #20
    def _get_data(self, mode, **kwargs):
        builder_config = self.BUILDER_CONFIGS[self.name]
        if self.name != 'mrpc':
            default_root = os.path.join(DATA_HOME, self.__class__.__name__)
            filename, data_hash, _, _ = builder_config['splits'][mode]
            fullname = os.path.join(default_root, filename)
            if not os.path.exists(fullname) or (
                    data_hash and not md5file(fullname) == data_hash):
                get_path_from_url(builder_config['url'], default_root,
                                  builder_config['md5'])

        else:
            default_root = os.path.join(DATA_HOME, self.__class__.__name__)
            filename, data_hash, _, _ = builder_config['splits'][mode]
            fullname = os.path.join(default_root, filename)
            if not os.path.exists(fullname) or (
                    data_hash and not md5file(fullname) == data_hash):
                if mode in ('train', 'dev'):
                    dev_id_path = get_path_from_url(
                        builder_config['url']['dev_id'],
                        os.path.join(default_root, 'MRPC'),
                        builder_config['md5']['dev_id'])
                    train_data_path = get_path_from_url(
                        builder_config['url']['train_data'],
                        os.path.join(default_root, 'MRPC'),
                        builder_config['md5']['train_data'])
                    # read dev data ids
                    dev_ids = []
                    print(dev_id_path)
                    with open(dev_id_path, encoding='utf-8') as ids_fh:
                        for row in ids_fh:
                            dev_ids.append(row.strip().split('\t'))

                    # generate train and dev set
                    train_path = os.path.join(default_root, 'MRPC',
                                              'train.tsv')
                    dev_path = os.path.join(default_root, 'MRPC', 'dev.tsv')
                    with open(train_data_path, encoding='utf-8') as data_fh:
                        with open(train_path, 'w',
                                  encoding='utf-8') as train_fh:
                            with open(dev_path, 'w',
                                      encoding='utf8') as dev_fh:
                                header = data_fh.readline()
                                train_fh.write(header)
                                dev_fh.write(header)
                                for row in data_fh:
                                    label, id1, id2, s1, s2 = row.strip(
                                    ).split('\t')
                                    example = '%s\t%s\t%s\t%s\t%s\n' % (
                                        label, id1, id2, s1, s2)
                                    if [id1, id2] in dev_ids:
                                        dev_fh.write(example)
                                    else:
                                        train_fh.write(example)

                else:
                    test_data_path = get_path_from_url(
                        builder_config['url']['test_data'],
                        os.path.join(default_root, 'MRPC'),
                        builder_config['md5']['test_data'])
                    test_path = os.path.join(default_root, 'MRPC', 'test.tsv')
                    with open(test_data_path, encoding='utf-8') as data_fh:
                        with open(test_path, 'w', encoding='utf-8') as test_fh:
                            header = data_fh.readline()
                            test_fh.write(
                                'index\t#1 ID\t#2 ID\t#1 String\t#2 String\n')
                            for idx, row in enumerate(data_fh):
                                label, id1, id2, s1, s2 = row.strip().split(
                                    '\t')
                                test_fh.write('%d\t%s\t%s\t%s\t%s\n' %
                                              (idx, id1, id2, s1, s2))

        return fullname
Example #21
    def _get_data(self, root, segment, **kwargs):
        default_root = os.path.join(DATA_HOME, 'glue')
        filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[
            segment]
        fullname = os.path.join(default_root,
                                filename) if root is None else os.path.join(
                                    os.path.expanduser(root), filename)
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            if root is not None:  # warn only when a root was explicitly given
                warnings.warn(
                    'md5 check failed for {}, download {} data to {}'.format(
                        filename, self.__class__.__name__, default_root))
            if segment in ('train', 'dev'):
                dev_id_path = get_path_from_url(
                    self.DEV_ID_URL, os.path.join(default_root, 'MRPC'),
                    self.DEV_ID_MD5)
                train_data_path = get_path_from_url(
                    self.TRAIN_DATA_URL, os.path.join(default_root, 'MRPC'),
                    self.TRAIN_DATA_MD5)
                # read dev data ids
                dev_ids = []
                with io.open(dev_id_path, encoding='utf-8') as ids_fh:
                    for row in ids_fh:
                        dev_ids.append(row.strip().split('\t'))

                # generate train and dev set
                train_path = os.path.join(default_root, 'MRPC', 'train.tsv')
                dev_path = os.path.join(default_root, 'MRPC', 'dev.tsv')
                with io.open(train_data_path, encoding='utf-8') as data_fh:
                    with io.open(train_path, 'w',
                                 encoding='utf-8') as train_fh:
                        with io.open(dev_path, 'w', encoding='utf8') as dev_fh:
                            header = data_fh.readline()
                            train_fh.write(header)
                            dev_fh.write(header)
                            for row in data_fh:
                                label, id1, id2, s1, s2 = row.strip().split(
                                    '\t')
                                example = '%s\t%s\t%s\t%s\t%s\n' % (
                                    label, id1, id2, s1, s2)
                                if [id1, id2] in dev_ids:
                                    dev_fh.write(example)
                                else:
                                    train_fh.write(example)
            else:
                test_data_path = get_path_from_url(
                    self.TEST_DATA_URL, os.path.join(default_root, 'MRPC'),
                    self.TEST_DATA_MD5)
                test_path = os.path.join(default_root, 'MRPC', 'test.tsv')
                with io.open(test_data_path, encoding='utf-8') as data_fh:
                    with io.open(test_path, 'w', encoding='utf-8') as test_fh:
                        header = data_fh.readline()
                        test_fh.write(
                            'index\t#1 ID\t#2 ID\t#1 String\t#2 String\n')
                        for idx, row in enumerate(data_fh):
                            label, id1, id2, s1, s2 = row.strip().split('\t')
                            test_fh.write('%d\t%s\t%s\t%s\t%s\n' %
                                          (idx, id1, id2, s1, s2))
            root = default_root
        super(GlueMRPC, self)._get_data(root, segment, **kwargs)