Example #1
def get_model_file(model_name,
                   local_model_store_dir_path=os.path.join(
                       '~', '.mxnet', 'models')):
    """
    Return the location of a pretrained model on the local file system. This function downloads the model
    from the online model zoo when it cannot be found locally or its hash does not match. The root
    directory will be created if it doesn't exist.

    Parameters
    ----------
    model_name : str
        Name of the model.
    local_model_store_dir_path : str, default $MXNET_HOME/models
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.
    """
    error, sha1_hash, repo_release_tag = get_model_name_suffix_data(model_name)
    short_sha1 = sha1_hash[:8]
    file_name = '{name}-{error}-{short_sha1}.params'.format(
        name=model_name, error=error, short_sha1=short_sha1)
    local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
    file_path = os.path.join(local_model_store_dir_path, file_name)
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return file_path
        else:
            logging.warning(
                'Mismatch in the content of model file detected. Downloading again.'
            )
    else:
        logging.info(
            'Model file not found. Downloading to {}.'.format(file_path))

    if not os.path.exists(local_model_store_dir_path):
        os.makedirs(local_model_store_dir_path)

    zip_file_path = file_path + '.zip'
    download(
        url='{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip'.
        format(repo_url=imgclsmob_repo_url,
               repo_release_tag=repo_release_tag,
               file_name=file_name),
        path=zip_file_path,
        overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(local_model_store_dir_path)
    os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return file_path
    else:
        raise ValueError(
            'Downloaded file has different hash. Please try again.')
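A quick usage sketch (the model name is illustrative; it must be one known to get_model_name_suffix_data):

params_path = get_model_file('resnet18')
# Resolves to ~/.mxnet/models/<name>-<error>-<short_sha1>.params, downloading and
# unpacking the release zip first if the file is missing or its hash mismatches.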
Example #2
    def _generate(self, segment):
        """Partition MRPC dataset into train, dev and test.
        Adapted from https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
        """
        # download raw data
        data_name = segment + '.tsv'
        raw_name, raw_hash, data_hash = self._data_file[segment]
        raw_path = os.path.join(self._root, raw_name)
        download(self._repo_dir() + raw_name, path=raw_path, sha1_hash=raw_hash)
        data_path = os.path.join(self._root, data_name)

        if segment in ('train', 'dev'):
            if os.path.isfile(data_path) and check_sha1(data_path, data_hash):
                return

            # retrieve dev ids for train and dev set
            DEV_ID_URL = 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc'
            DEV_ID_HASH = '506c7a1a5e0dd551ceec2f84070fa1a8c2bc4b41'
            dev_id_name = 'dev_ids.tsv'
            dev_id_path = os.path.join(self._root, dev_id_name)
            download(DEV_ID_URL, path=dev_id_path, sha1_hash=DEV_ID_HASH)

            # read dev data ids
            dev_ids = []
            with io.open(dev_id_path, encoding='utf8') as ids_fh:
                for row in ids_fh:
                    dev_ids.append(row.strip().split('\t'))

            # generate train and dev set
            train_path = os.path.join(self._root, 'train.tsv')
            dev_path = os.path.join(self._root, 'dev.tsv')
            with io.open(raw_path, encoding='utf8') as data_fh:
                with io.open(train_path, 'w', encoding='utf8') as train_fh:
                    with io.open(dev_path, 'w', encoding='utf8') as dev_fh:
                        header = data_fh.readline()
                        train_fh.write(header)
                        dev_fh.write(header)
                        for row in data_fh:
                            label, id1, id2, s1, s2 = row.strip().split('\t')
                            example = '%s\t%s\t%s\t%s\t%s\n'%(label, id1, id2, s1, s2)
                            if [id1, id2] in dev_ids:
                                dev_fh.write(example)
                            else:
                                train_fh.write(example)
        else:
            # generate test set
            if os.path.isfile(data_path) and check_sha1(data_path, data_hash):
                return
            with io.open(raw_path, encoding='utf8') as data_fh:
                with io.open(data_path, 'w', encoding='utf8') as test_fh:
                    header = data_fh.readline()
                    test_fh.write('index\t#1 ID\t#2 ID\t#1 String\t#2 String\n')
                    for idx, row in enumerate(data_fh):
                        label, id1, id2, s1, s2 = row.strip().split('\t')
                        test_fh.write('%d\t%s\t%s\t%s\t%s\n'%(idx, id1, id2, s1, s2))
Example #3
def download(url, path=None, overwrite=False, sha1_hash=None, verify=True):
    """Download an given URL

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path for the downloaded file. By default the file is stored in the
        current directory under the same name as in the URL.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. An existing file is ignored and re-downloaded
        when a hash is specified but doesn't match.
    verify : bool
        Toggle verification of SSL certificates.

    Returns
    -------
    str
        The file path of the downloaded file.
    """
    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    if overwrite or not os.path.exists(fname) or (
            sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        print('Downloading %s from %s...' % (fname, url))
        r = requests.get(url, stream=True, verify=verify)
        if r.status_code != 200:
            raise RuntimeError('Failed downloading url %s' % url)
        with open(fname, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                              'The repo may be outdated or download may be incomplete. ' \
                              'If the "repo_url" is overridden, consider switching to ' \
                              'the default repo.'.format(fname))

    return fname
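A minimal usage sketch of download (the URL and hash below are placeholders, not real artifacts):

# When `path` is an existing directory, the file keeps the basename from the URL;
# passing `sha1_hash` makes a matching, already-downloaded file a no-op.
fname = download('https://example.com/artifacts/weights.params',
                 path='data',
                 sha1_hash='0123456789abcdef0123456789abcdef01234567')
assert check_sha1(fname, '0123456789abcdef0123456789abcdef01234567')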
Example #4
def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
    r"""Return location for the pretrained on local file system.

    This function will download from online model zoo when model cannot be found or has mismatch.
    The root directory will be created if it doesn't exist.

    Parameters
    ----------
    name : str
        Name of the model.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.
    """
    file_name = '{name}-{short_hash}'.format(name=name,
                                             short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name + '.params')
    sha1_hash = _model_sha1[name]
    print("find file path: {}".format(file_path))
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return file_path
        else:
            print(
                'Mismatch in the content of model file detected. Downloading again.'
            )
    else:
        print('Model file is not found. Downloading.')

    if not os.path.exists(root):
        os.makedirs(root)

    zip_file_path = os.path.join(root, file_name + '.zip')
    repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path,
             overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
    os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return file_path
    else:
        raise ValueError(
            'Downloaded file has different hash. Please try again.')
Example #5
    def __init__(self, root=os.path.join(get_home_dir(), 'models')):
        try:
            import regex  # pylint: disable=import-outside-toplevel
            self._regex = regex
        except ImportError:
            raise ImportError('GPT2BPETokenizer requires regex. '
                              'To install regex, use pip install -U regex')
        super(GPT2BPETokenizer, self).__init__()
        root = os.path.expanduser(root)
        file_name, sha1_hash = self.bpe_ranks_file_hash
        file_path = os.path.join(root, file_name)
        if not os.path.exists(file_path) or not check_sha1(
                file_path, sha1_hash):
            if os.path.exists(file_path):
                print(
                    'Detected mismatch in the content of BPE rank file. Downloading again.'
                )
            else:
                print('BPE rank file is not found. Downloading.')
            os.makedirs(root, exist_ok=True)

            prefix = str(time.time())
            zip_file_path = os.path.join(root, prefix + file_name)
            repo_url = _get_repo_url()
            if repo_url[-1] != '/':
                repo_url = repo_url + '/'
            archive_name, archive_hash = self.bpe_ranks_archive_hash
            _url_format = '{repo_url}gluon/dataset/vocab/{file_name}'
            download(_url_format.format(repo_url=repo_url,
                                        file_name=archive_name),
                     path=zip_file_path,
                     sha1_hash=archive_hash,
                     overwrite=True)
            with zipfile.ZipFile(zip_file_path) as zf:
                if not os.path.exists(file_path):
                    zf.extractall(root)
            try:
                os.remove(zip_file_path)
            except OSError as e:
                # file has already been removed.
                if e.errno == 2:
                    pass
                else:
                    raise e

            if not check_sha1(file_path, sha1_hash):
                raise ValueError(
                    'Downloaded file has different hash. Please try again.')
        self._read_bpe_ranks(file_path)
        self._cache = {}
        self._token_pattern = self._regex.compile(
            r'\'s|\'t|\'re|\'ve|\'m|\'ll|\'d| ?\p{L}+'
            r'| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+')
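To illustrate the token pattern above (a hypothetical snippet; requires the third-party regex module):

import regex
pat = regex.compile(r'\'s|\'t|\'re|\'ve|\'m|\'ll|\'d| ?\p{L}+'
                    r'| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+')
print(pat.findall("I'll pay 42 dollars!"))
# ['I', "'ll", ' pay', ' 42', ' dollars', '!']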
Example #6
def _load_pretrained_vocab(name,
                           root=os.path.join(get_home_dir(), 'models'),
                           cls=None):
    """Load the accompanying vocabulary object for pre-trained model.

    Parameters
    ----------
    name : str
        Name of the vocabulary, usually the name of the dataset.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
        MXNET_HOME defaults to '~/.mxnet'.
    cls : nlp.Vocab or nlp.vocab.BERTVocab, default nlp.Vocab

    Returns
    -------
    Vocab or nlp.vocab.BERTVocab
        Loaded vocabulary object for the pre-trained model.
    """
    file_name = '{name}-{short_hash}'.format(name=name,
                                             short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name + '.vocab')
    sha1_hash = _vocab_sha1[name]
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return _load_vocab_file(file_path, cls)
        else:
            print(
                'Detected mismatch in the content of model vocab file. Downloading again.'
            )
    else:
        print('Vocab file is not found. Downloading.')

    if not os.path.exists(root):
        os.makedirs(root)

    zip_file_path = os.path.join(root, file_name + '.zip')
    repo_url = _get_repo_url()
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path,
             overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
    os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return _load_vocab_file(file_path, cls)
    else:
        raise ValueError(
            'Downloaded file has different hash. Please try again.')
Example #7
def _load_pretrained_vocab(name, root, cls=None):
    """Load the accompanying vocabulary object for pre-trained model.

    Parameters
    ----------
    name : str
        Name of the vocabulary, usually the name of the dataset.
    root : str
        Location for keeping the model vocabulary.
    cls : nlp.Vocab or nlp.vocab.BERTVocab, default nlp.Vocab

    Returns
    -------
    Vocab or nlp.vocab.BERTVocab
        Loaded vocabulary object for the pre-trained model.
    """
    file_name = '{name}-{short_hash}'.format(name=name,
                                             short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name + '.vocab')
    sha1_hash = _vocab_sha1[name]

    temp_num = str(random.Random().randint(1, sys.maxsize))
    temp_root = os.path.join(root, temp_num)
    temp_file_path = os.path.join(temp_root, file_name + '.vocab')
    temp_zip_file_path = os.path.join(root, temp_num + file_name + '.zip')
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return _load_vocab_file(file_path, cls)
        else:
            print('Detected mismatch in the content of model vocab file. Downloading again.')
    else:
        print('Vocab file is not found. Downloading.')

    utils.mkdir(root)

    repo_url = _get_repo_url()
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=temp_zip_file_path, overwrite=True)
    with zipfile.ZipFile(temp_zip_file_path) as zf:
        if not os.path.exists(file_path):
            utils.mkdir(temp_root)
            zf.extractall(temp_root)
            os.replace(temp_file_path, file_path)
            shutil.rmtree(temp_root)

    if check_sha1(file_path, sha1_hash):
        return _load_vocab_file(file_path, cls)
    else:
        raise ValueError('Downloaded file has different hash. Please try again.')
Example #8
def _get_data(shape):
    hash_test_img = "355e15800642286e7fe607d87c38aeeab085b0cc"
    hash_inception_v3 = "91807dfdbd336eb3b265dd62c2408882462752b9"
    fname = utils.download("http://data.mxnet.io/data/test_images_%d_%d.npy" % (shape),
                           path="data/test_images_%d_%d.npy" % (shape),
                           sha1_hash=hash_test_img)
    if not utils.check_sha1(fname, hash_test_img):
        raise RuntimeError("File %s not downloaded completely" % ("test_images_%d_%d.npy"%(shape)))

    fname = utils.download("http://data.mxnet.io/data/inception-v3-dump.npz",
                           path='data/inception-v3-dump.npz',
                           sha1_hash=hash_inception_v3)
    if not utils.check_sha1(fname, hash_inception_v3):
        raise RuntimeError("File %s not downloaded completely" % ("inception-v3-dump.npz"))
Example #9
    def __init__(self, segmenter_root=os.path.join(_get_home_dir(), 'stanford-segmenter'),
                 slf4j_root=os.path.join(_get_home_dir(), 'slf4j'),
                 java_class='edu.stanford.nlp.ie.crf.CRFClassifier'):
        is_java_exist = os.system('java -version')
        assert is_java_exist == 0, 'Java is not installed. You must install Java 8.0 ' \
                                   'in order to use the NLTKStanfordSegmenter'
        try:
            from nltk.tokenize import StanfordSegmenter
        except ImportError:
            raise ImportError('NLTK or relevant packages are not installed. You must install NLTK '
                              'in order to use the NLTKStanfordSegmenter. You can refer to the '
                              'official installation guide in https://www.nltk.org/install.html.')
        path_to_jar = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27',
                                   'stanford-segmenter-3.9.1.jar')
        path_to_model = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27',
                                     'data', 'pku.gz')
        path_to_dict = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27',
                                    'data', 'dict-chris6.ser.gz')
        path_to_sihan_corpora_dict = os.path.join(segmenter_root,
                                                  'stanford-segmenter-2018-02-27', 'data')
        segmenter_url = 'https://nlp.stanford.edu/software/stanford-segmenter-2018-02-27.zip'
        segmenter_sha1 = 'aa27a6433704b7b4c6a14be1c126cb4b14b8f57b'
        stanford_segmenter = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27.zip')
        if not os.path.exists(path_to_jar) or \
                not os.path.exists(path_to_model) or \
                not os.path.exists(path_to_dict) or \
                not os.path.exists(path_to_sihan_corpora_dict) or \
                not check_sha1(filename=stanford_segmenter, sha1_hash=segmenter_sha1):
            # automatically download the files from the website and place them to stanford_root
            if not os.path.exists(segmenter_root):
                os.mkdir(segmenter_root)
            download(url=segmenter_url, path=segmenter_root, sha1_hash=segmenter_sha1)
            _extract_archive(file=stanford_segmenter, target_dir=segmenter_root)

        path_to_slf4j = os.path.join(slf4j_root, 'slf4j-1.7.25', 'slf4j-api-1.7.25.jar')
        slf4j_url = 'https://www.slf4j.org/dist/slf4j-1.7.25.zip'
        slf4j_sha1 = '89ea41ad6ebe1b190139421bb7c8d981e9df1625'
        slf4j = os.path.join(slf4j_root, 'slf4j-1.7.25.zip')
        if not os.path.exists(path_to_slf4j) or \
                not check_sha1(filename=slf4j, sha1_hash=slf4j_sha1):
            # automatically download the files from the website and place them to slf4j_root
            if not os.path.exists(slf4j_root):
                os.mkdir(slf4j_root)
            download(url=slf4j_url, path=slf4j_root, sha1_hash=slf4j_sha1)
            _extract_archive(file=slf4j, target_dir=slf4j_root)
        self._tokenizer = StanfordSegmenter(java_class=java_class, path_to_jar=path_to_jar,
                                            path_to_slf4j=path_to_slf4j, path_to_dict=path_to_dict,
                                            path_to_sihan_corpora_dict=path_to_sihan_corpora_dict,
                                            path_to_model=path_to_model)
Example #11
def _get_xlnet_tokenizer(dataset_name, root):
    assert dataset_name.lower() == '126gb'
    root = os.path.expanduser(root)
    file_path = os.path.join(root, 'xlnet_126gb-871f0b3c.spiece')
    sha1_hash = '871f0b3c13b92fc5aea8fba054a214c420e302fd'
    if os.path.exists(file_path):
        if not check_sha1(file_path, sha1_hash):
            print(
                'Detected mismatch in the content of model tokenizer. Downloading again.'
            )
    else:
        print('Tokenizer file is not found. Downloading.')

    if not os.path.exists(root):
        try:
            os.makedirs(root)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(root):
                pass
            else:
                raise e

    repo_url = _get_repo_url()
    prefix = str(time.time())
    zip_file_path = os.path.join(root, prefix + 'xlnet_126gb-871f0b3c.zip')
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url,
                                file_name='xlnet_126gb-871f0b3c'),
             path=zip_file_path,
             overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        if not os.path.exists(file_path):
            zf.extractall(root)
    try:
        os.remove(zip_file_path)
    except OSError as e:
        # file has already been removed.
        if e.errno == 2:
            pass
        else:
            raise e

    if not check_sha1(file_path, sha1_hash):
        raise ValueError(
            'Downloaded file has different hash. Please try again.')

    tokenizer = XLNetTokenizer(file_path)
    return tokenizer
Example #12
    def _download_data(self):
        _, archive_hash = self._archive_file
        for name, checksum in self._checksums.items():
            name = name.split('/')
            path = os.path.join(self.root, *name)
            if not os.path.exists(path) or not check_sha1(path, checksum):
                if self._namespace is not None:
                    url = _get_repo_file_url(self._namespace,
                                             self._archive_file[0])
                else:
                    url = self._url
                downloaded_file_path = download(url,
                                                path=self.root,
                                                sha1_hash=archive_hash,
                                                verify_ssl=self._verify_ssl)

                if downloaded_file_path.lower().endswith('zip'):
                    with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
                        zf.extractall(path=self.root)
                elif downloaded_file_path.lower().endswith('tar.gz'):
                    with tarfile.open(downloaded_file_path, 'r') as tf:
                        tf.extractall(path=self.root)
                elif len(self._checksums) > 1:
                    err = 'Failed retrieving {clsname}.'.format(
                        clsname=self.__class__.__name__)
                    err += (' Expecting multiple files, '
                            'but could not detect archive format.')
                    raise RuntimeError(err)
Example #13
 def _get_data(self):
     data_file_name, data_hash = self._data_file()[self._segment]
     root = self._root
     path = os.path.join(root, data_file_name)
     if not os.path.exists(path) or not check_sha1(path, data_hash):
         download(_get_repo_file_url(self._repo_dir(), data_file_name),
                  path=root, sha1_hash=data_hash)
Example #14
    def _get_data(self):
        if any(not os.path.exists(path) or not check_sha1(path, sha1)
               for path, sha1 in ((os.path.join(self._root, name), sha1)
                                  for name, sha1 in self._train_data +
                                  self._test_data)):
            namespace = 'gluon/dataset/' + self._namespace
            filename = download(_get_repo_file_url(namespace,
                                                   self._archive_file[0]),
                                path=self._root,
                                sha1_hash=self._archive_file[1])

            with tarfile.open(filename) as tar:
                tar.extractall(self._root)

        if self._train:
            data_files = self._train_data
        else:
            data_files = self._test_data
        data, label = zip(*(self._read_batch(os.path.join(self._root, name))
                            for name, _ in data_files))
        data = np.concatenate(data)
        label = np.concatenate(label)
        if self._train:
            npr.seed(0)
            rand_inds = npr.permutation(50000)
            data = data[rand_inds]
            label = label[rand_inds]
            data = data[self._split_id * 10000:(self._split_id + 1) * 10000]
            label = label[self._split_id * 10000:(self._split_id + 1) * 10000]
        self._data = nd.array(data, dtype=data.dtype)
        self._label = label
Example #15
def check_file(filename, checksum, sha1):
    from mxnet.gluon.utils import check_sha1

    if not os.path.exists(filename):
        raise ValueError('File not found: '+filename)
    if checksum and not check_sha1(filename, sha1):
        raise ValueError('Corrupted file: '+filename)
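For reference, check_sha1 from mxnet.gluon.utils is essentially the following (a sketch, not the verbatim library source): it hashes the file in 1 MiB chunks, as Example #16 below does inline, and compares hex digests.

import hashlib

def check_sha1_sketch(filename, sha1_hash):
    """Return True if the sha1 of the file content matches the expected hex digest."""
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        while True:
            data = f.read(1048576)  # 1 MiB chunks
            if not data:
                break
            sha1.update(data)
    return sha1.hexdigest() == sha1_hash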
Example #16
 def _get_data(self):
     archive_file_name, archive_hash = self._archive_file
     archive_file_path = os.path.join(self._root, archive_file_name)
     exists = False
     if os.path.exists(self._dir) and os.path.exists(self._subdir):
         # verify sha1 for all files in the subdir
         sha1 = hashlib.sha1()
         filenames = sorted(glob.glob(self._file_pattern))
         for filename in filenames:
             with open(filename, 'rb') as f:
                 while True:
                     data = f.read(1048576)
                     if not data:
                         break
                     sha1.update(data)
         if sha1.hexdigest() == self._data_hash:
             exists = True
     if not exists:
         # download archive
         if not os.path.exists(archive_file_path) or \
            not check_sha1(archive_file_path, archive_hash):
             download(_get_repo_file_url(self._namespace,
                                         archive_file_name),
                      path=self._root,
                      sha1_hash=archive_hash)
         # extract archive
         with tarfile.open(archive_file_path, 'r:gz') as tf:
             tf.extractall(path=self._root)
Example #17
def get_model_file(name, root=os.path.join('~', '.mxnet', 'models')):
    r"""Return location for the pretrained on local file system.

    This function will download from online model zoo when model cannot be found or has mismatch.
    The root directory will be created if it doesn't exist.

    Parameters
    ----------
    name : str
        Name of the model.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.
    """
    file_name = '{name}-{short_hash}'.format(name=name,
                                             short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name + '.params')
    sha1_hash = _model_sha1[name]
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return file_path
        else:
            raise AssertionError(
                'Mismatch in the content of model file detected. Please download it again.'
            )
    else:
        raise AssertionError(
            'Model file: %s is not found. Please download it before use.' %
            file_path)
Example #18
def _load_pretrained_vocab(name, root=os.path.join('~', '.mxnet', 'models')):
    """Load the accompanying vocabulary object for pre-trained model.

    Parameters
    ----------
    name : str
        Name of the vocabulary, usually the name of the dataset.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Returns
    -------
    Vocab
        Loaded vocabulary object for the pre-trained model.
    """
    file_name = '{name}-{short_hash}'.format(name=name,
                                             short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name+'.vocab')
    sha1_hash = _vocab_sha1[name]
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return _load_vocab_file(file_path)
        else:
            print('Detected mismatch in the content of model vocab file. Downloading again.')
    else:
        print('Vocab file is not found. Downloading.')

    if not os.path.exists(root):
        os.makedirs(root)

    zip_file_path = os.path.join(root, file_name+'.zip')
    repo_url = _get_repo_url()
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path,
             overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
    os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return _load_vocab_file(file_path)
    else:
        raise ValueError('Downloaded file has different hash. Please try again.')
Example #19
 def _get_data(self):
     filename_format, sha1_hash = self._download_info
     filename = filename_format.format(sha1_hash[:8])
     data_filename = os.path.join(self._root, filename)
     url = _get_repo_file_url('gluon/dataset', filename)
     if not os.path.exists(data_filename) or not check_sha1(
             data_filename, sha1_hash):
         download(url, path=data_filename, sha1_hash=sha1_hash)
         with zipfile.ZipFile(data_filename, 'r') as zf:
             zf.extractall(self._root)
Example #20
def _get_file_path(cls_name, file_name, file_hash):
    root_path = os.path.expanduser(os.path.join(get_home_dir(), 'embedding'))
    embedding_dir = os.path.join(root_path, cls_name)
    url = _get_file_url(cls_name, file_name)
    file_path = os.path.join(embedding_dir, file_name)
    if not os.path.exists(file_path) or not check_sha1(file_path, file_hash):
        logging.info(
            'Embedding file {} is not found. Downloading from Gluon Repository. '
            'This may take some time.'.format(file_name))
        download(url, file_path, sha1_hash=file_hash)
    return file_path
Example #21
 def _get_data(self):
     filename, url, sha1_hash = self._download_info
     data_filename = os.path.join(self._root, filename)
     if not os.path.exists(data_filename) or not check_sha1(
             data_filename, sha1_hash):
         download(url,
                  path=data_filename,
                  sha1_hash=sha1_hash,
                  verify_ssl=True)
         with zipfile.ZipFile(data_filename, 'r') as zf:
             zf.extractall(self._root)
Example #22
def _load_pretrained_sentencepiece_tokenizer(name, root, **kwargs):
    from ..data import SentencepieceTokenizer  # pylint: disable=import-outside-toplevel
    file_name, file_ext, sha1_hash, _ = _get_vocab_tokenizer_info(name, root)
    file_path = os.path.join(root, file_name + file_ext)
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            assert file_path.endswith('.spiece')
            return SentencepieceTokenizer(file_path, **kwargs)
        else:
            print(
                'Detected mismatch in the content of model tokenizer file. Downloading again.'
            )
    else:
        print('Tokenizer file is not found. Downloading.')
    _download_vocab_tokenizer(root, file_name, file_ext, file_path)
    if check_sha1(file_path, sha1_hash):
        assert file_path.endswith('.spiece')
        return SentencepieceTokenizer(file_path, **kwargs)
    else:
        raise ValueError(
            'Downloaded file has different hash. Please try again.')
Example #23
 def _get_data(self, force_download=False):
     if not os.path.exists(self._root):
         os.makedirs(self._root)
     data_name, url, data_hash = self._raw_data_urls[self._name]
     path = os.path.join(self._root, data_name)
     if not os.path.exists(path) or force_download or (
             data_hash and not check_sha1(path, data_hash)):
         print("\n\n=====================> Download dataset ...")
         self.download(url, path=path, sha1_hash=data_hash)
         print("\n\n=====================> Unzip the file ...")
         with ZipFile(path, 'r') as zf:
             zf.extractall(path=self._root)
Example #24
    def _get_data(self):
        archive_file_name, archive_hash = self.archive_file
        data_file_name, data_hash = self.data_file[self._segment]
        root = self._root
        path = os.path.join(root, data_file_name)
        if not os.path.exists(path) or not check_sha1(path, data_hash):
            downloaded_file_path = download(self.url + archive_file_name,
                                            path=root, sha1_hash=archive_hash)

            with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
                zf.extractall(root)
        return path
Example #25
 def _get_data(self):
     archive_file_name, archive_hash = self._get_data_archive_hash()
     paths = []
     for data_file_name, data_hash in self._get_data_file_hash():
         root = self._root
         path = os.path.join(root, data_file_name)
         if not os.path.exists(path) or not check_sha1(path, data_hash):
             download(self.base_url + archive_file_name,
                      path=root, sha1_hash=archive_hash)
             self._extract_archive()
         paths.append(path)
     return paths
Example #26
    def _get_data(archive_file, data_file, segment, root, namespace):
        archive_file_name, archive_hash = archive_file
        data_file_name, data_hash = data_file[segment]
        path = os.path.join(root, data_file_name)
        if not os.path.exists(path) or not check_sha1(path, data_hash):
            downloaded_file_path = download(_get_repo_file_url(
                namespace, archive_file_name),
                                            path=root,
                                            sha1_hash=archive_hash)

            with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
                zf.extractall(root)
        return path
Example #27
    def _get_vocab(self):
        archive_file_name, archive_hash = self._archive_vocab
        vocab_file_name, vocab_hash = self._vocab_file
        namespace = 'gluon/dataset/vocab'
        root = self._root
        path = os.path.join(root, vocab_file_name)
        if not os.path.exists(path) or not check_sha1(path, vocab_hash):
            downloaded_path = download(_get_repo_file_url(namespace, archive_file_name),
                                       path=root, sha1_hash=archive_hash)

            with zipfile.ZipFile(downloaded_path, 'r') as zf:
                zf.extractall(path=root)
        return path
Example #28
 def _get_data(self):
     archive_file_name, archive_hash = self._get_data_archive_hash()
     paths = []
     for data_file_name, data_hash in self._get_data_file_hash():
         root = self._root
         path = os.path.join(root, data_file_name)
         if not os.path.exists(path) or not check_sha1(path, data_hash):
             download(self.base_url + archive_file_name,
                      path=root,
                      sha1_hash=archive_hash)
             self._extract_archive()
         paths.append(path)
     return paths
Example #29
 def _get_data(self, segment, zip_hash, data_hash, filename):
     data_filename = '%s-%s.zip' % (segment, data_hash[:8])
     if not os.path.exists(filename) or not check_sha1(filename, data_hash):
         download(_get_repo_file_url(self._repo_dir(), data_filename),
                  path=self._root, sha1_hash=zip_hash)
         # unzip
         downloaded_path = os.path.join(self._root, data_filename)
         with zipfile.ZipFile(downloaded_path, 'r') as zf:
             # skip dir structures in the zip
             for zip_info in zf.infolist():
                 if zip_info.filename[-1] == '/':
                     continue
                 zip_info.filename = os.path.basename(zip_info.filename)
                 zf.extract(zip_info, self._root)
Example #30
 def _get_data(self):
     archive_file_name, archive_hash = self._get_data_archive_hash()
     paths = []
     for data_file_name, data_hash in self._get_data_file_hash():
         root = self._root
         path = os.path.join(root, data_file_name)
         if hasattr(self, 'namespace'):
             url = _get_repo_file_url(self.namespace, archive_file_name)
         else:
             url = self.base_url + archive_file_name
         if not os.path.exists(path) or not check_sha1(path, data_hash):
             download(url, path=root, sha1_hash=archive_hash)
             self._extract_archive()
         paths.append(path)
     return paths
Example #31
def download_city(path, overwrite=False):
    _CITY_DOWNLOAD_URLS = [
        ('gtFine_trainvaltest.zip', '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'),
        ('leftImg8bit_trainvaltest.zip', '2c0b77ce9933cc635adda307fbba5566f5d9d404')]
    download_dir = os.path.join(path, 'downloads')
    makedirs(download_dir)
    for filename, checksum in _CITY_DOWNLOAD_URLS:
        if not check_sha1(filename, checksum):
            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                              'The repo may be outdated or download may be incomplete. ' \
                              'If the "repo_url" is overridden, consider switching to ' \
                              'the default repo.'.format(filename))
        # extract
        with zipfile.ZipFile(filename,"r") as zip_ref:
            zip_ref.extractall(path=path)
        print("Extracted", filename)
Example #32
    def _get_file_path(cls, source_file_hash, embedding_root, source):
        cls_name = cls.__name__.lower()
        embedding_root = os.path.expanduser(embedding_root)
        url = cls._get_file_url(cls_name, source_file_hash, source)

        embedding_dir = os.path.join(embedding_root, cls_name)

        pretrained_file_name, expected_file_hash = source_file_hash[source]
        pretrained_file_path = os.path.join(embedding_dir, pretrained_file_name)

        if not os.path.exists(pretrained_file_path) \
           or not check_sha1(pretrained_file_path, expected_file_hash):
            print('Embedding file {} is not found. Downloading from Gluon Repository. '
                  'This may take some time.'.format(pretrained_file_name))
            download(url, pretrained_file_path, sha1_hash=expected_file_hash)

        return pretrained_file_path
Example #35
    def _get_data(self):
        """Load data from the file. Do nothing if data was loaded before.
        """
        (data_archive_name, archive_hash), (data_name, data_hash) \
            = self._data_file()[self._segment]
        data_path = os.path.join(self._root, data_name)

        if not os.path.exists(data_path) or not check_sha1(data_path, data_hash):
            file_path = download(_get_repo_file_url(self._repo_dir(), data_archive_name),
                                 path=self._root, sha1_hash=archive_hash)

            with zipfile.ZipFile(file_path, 'r') as zf:
                for member in zf.namelist():
                    filename = os.path.basename(member)
                    if filename:
                        dest = os.path.join(self._root, filename)
                        with zf.open(member) as source, open(dest, 'wb') as target:
                            shutil.copyfileobj(source, target)
Example #36
    def _get_data(self):
        """Load data from the file. Does nothing if data was loaded before
        """
        data_archive_name, _, data_hash = self._data_file[self._segment]
        path = os.path.join(self._root, data_archive_name)

        if not os.path.exists(path) or not check_sha1(path, data_hash):
            file_path = download(_get_repo_file_url('gluon/dataset/squad', data_archive_name),
                                 path=self._root, sha1_hash=data_hash)

            with zipfile.ZipFile(file_path, 'r') as zf:
                for member in zf.namelist():
                    filename = os.path.basename(member)

                    if filename:
                        dest = os.path.join(self._root, filename)

                        with zf.open(member) as source, open(dest, 'wb') as target:
                            shutil.copyfileobj(source, target)
Example #37
    def _get_data(self):
        archive_file_name, archive_hash = self._archive_file
        data_file_name, data_hash = self._data_file[self._segment]
        root = self._root
        path = os.path.join(root, data_file_name)
        if not os.path.exists(path) or not check_sha1(path, data_hash):
            downloaded_file_path = download(_get_repo_file_url(self._namespace, archive_file_name),
                                            path=root,
                                            sha1_hash=archive_hash)

            with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
                for member in zf.namelist():
                    filename = os.path.basename(member)
                    if filename:
                        dest = os.path.join(root, filename)
                        with zf.open(member) as source, \
                             open(dest, 'wb') as target:
                            shutil.copyfileobj(source, target)
        return path
Example #38
def check_file(filename, checksum, sha1):
    if not os.path.exists(filename):
        raise ValueError('File not found: '+filename)
    if checksum and not check_sha1(filename, sha1):
        raise ValueError('Corrupted file: '+filename)
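Nearly all of the examples above reduce to the same cache-or-fetch idiom; distilled into a sketch (the function name is ours, not from any library, and it assumes the download and check_sha1 helpers shown above):

def fetch_verified(url, path, sha1_hash):
    # Reuse the cached file only while its content hash still matches.
    if os.path.exists(path) and check_sha1(path, sha1_hash):
        return path
    download(url, path=path, sha1_hash=sha1_hash, overwrite=True)
    if not check_sha1(path, sha1_hash):
        raise ValueError('Downloaded file has different hash. Please try again.')
    return path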