Пример #1
0
 def _get_data(self):
     """Fetch the current segment's data file unless a verified copy exists."""
     file_name, expected_sha1 = self._data_file()[self._segment]
     local_path = os.path.join(self._root, file_name)
     # Only hit the network when the file is absent or fails its checksum.
     needs_fetch = not os.path.exists(local_path) or not check_sha1(local_path, expected_sha1)
     if needs_fetch:
         url = _get_repo_file_url(self._repo_dir(), file_name)
         download(url, path=self._root, sha1_hash=expected_sha1)
Пример #2
0
    def _get_data(self):
        """Download/extract the dataset archive if needed, then load one split.

        Side effects: populates ``self._data`` (an ``nd.array``) and
        ``self._label``.
        """
        # Re-download the archive when any expected batch file is missing or
        # fails its SHA-1 check.
        if any(not os.path.exists(path) or not check_sha1(path, sha1)
               for path, sha1 in ((os.path.join(self._root, name), sha1)
                                  for name, sha1 in self._train_data +
                                  self._test_data)):
            namespace = 'gluon/dataset/' + self._namespace
            filename = download(_get_repo_file_url(namespace,
                                                   self._archive_file[0]),
                                path=self._root,
                                sha1_hash=self._archive_file[1])

            # NOTE(review): extractall trusts member paths inside the archive;
            # the SHA-1 check above is the only integrity guard.
            with tarfile.open(filename) as tar:
                tar.extractall(self._root)

        if self._train:
            data_files = self._train_data
        else:
            data_files = self._test_data
        # Read every batch file and stack the per-batch arrays.
        data, label = zip(*(self._read_batch(os.path.join(self._root, name))
                            for name, _ in data_files))
        data = np.concatenate(data)
        label = np.concatenate(label)
        if self._train:
            # Deterministic shuffle (fixed seed) so all instances agree on the
            # permutation, then keep the 10k slice selected by _split_id.
            # Assumes the training set holds exactly 50000 samples — TODO confirm.
            npr.seed(0)
            rand_inds = npr.permutation(50000)
            data = data[rand_inds]
            label = label[rand_inds]
            data = data[self._split_id * 10000:(self._split_id + 1) * 10000]
            label = label[self._split_id * 10000:(self._split_id + 1) * 10000]
        self._data = nd.array(data, dtype=data.dtype)
        self._label = label
Пример #3
0
 def _get_data(self, arg_segment, zip_hash, data_hash, filename):
     """Download and flatten-extract one segment archive into ``self._root``.

     ``arg_segment`` is the caller-facing split name ("dev" is mapped to
     "val"), ``zip_hash`` the SHA-1 of the zip archive, ``data_hash`` the
     SHA-1 of the data file (its first 8 hex chars name the archive), and
     ``filename`` the local file whose presence short-circuits the download.
     """
     # The GLUE API requires "dev", but these files are hashed as "val"
     if self.task == "MultiRC":
         # The MultiRC version from Gluon is quite outdated.
         # Make sure you've downloaded it and extracted it as described
         # in the README.MD file.
         print("Make sure you have downloaded the data!")
         print(
             "https://github.com/nyu-mll/jiant/blob/master/scripts/download_superglue_data.py"
         )
     segment = "val" if arg_segment == "dev" else arg_segment
     data_filename = '%s-%s.zip' % (segment, data_hash[:8])
     # NOTE(review): only ``filename``'s existence is checked (no checksum),
     # unlike sibling implementations that also run check_sha1 — confirm.
     if not os.path.exists(filename) and self.task != "MultiRC":
         download(_get_repo_file_url(self._repo_dir(), data_filename),
                  path=self._root,
                  sha1_hash=zip_hash)
         # unzip
         downloaded_path = os.path.join(self._root, data_filename)
         with zipfile.ZipFile(downloaded_path, 'r') as zf:
             # skip dir structures in the zip
             for zip_info in zf.infolist():
                 if zip_info.filename[-1] == '/':
                     continue
                 # Flatten: every regular member is extracted into self._root.
                 zip_info.filename = os.path.basename(zip_info.filename)
                 zf.extract(zip_info, self._root)
Пример #4
0
 def _get_data(self):
     """Download and unpack the dataset archive unless a verified copy exists."""
     archive_name, archive_sha1 = self._archive_file
     archive_path = os.path.join(self._root, archive_name)

     def _local_copy_verified():
         # Hash every file matching the pattern and compare against the
         # known digest of the extracted data.
         if not (os.path.exists(self._dir) and os.path.exists(self._subdir)):
             return False
         digest = hashlib.sha1()
         for fname in sorted(glob.glob(self._file_pattern)):
             with open(fname, 'rb') as fobj:
                 for chunk in iter(lambda: fobj.read(1048576), b''):
                     digest.update(chunk)
         return digest.hexdigest() == self._data_hash

     if _local_copy_verified():
         return
     # (Re-)download the archive when it is absent or checksum-mismatched.
     if not os.path.exists(archive_path) or \
        not check_sha1(archive_path, archive_sha1):
         download(_get_repo_file_url(self._namespace, archive_name),
                  path=self._root,
                  sha1_hash=archive_sha1)
     # Unpack the gzip-compressed tarball into the dataset root.
     with tarfile.open(archive_path, 'r:gz') as tf:
         tf.extractall(path=self._root)
    def _download_data(self):
        """Ensure every checksummed data file exists locally and verifies.

        For each missing or corrupt file, downloads the archive (or single
        file) and extracts it when it is a zip or tar.gz.
        """
        _, archive_hash = self._archive_file
        for name, checksum in self._checksums.items():
            # Checksum keys use '/' separators; map them to OS path segments.
            name = name.split('/')
            path = os.path.join(self.root, *name)
            if not os.path.exists(path) or not check_sha1(path, checksum):
                # Prefer the Gluon repo mirror when a namespace is configured;
                # otherwise fall back to the raw URL.
                if self._namespace is not None:
                    url = _get_repo_file_url(self._namespace,
                                             self._archive_file[0])
                else:
                    url = self._url
                downloaded_file_path = download(url,
                                                path=self.root,
                                                sha1_hash=archive_hash,
                                                verify_ssl=self._verify_ssl)

                # Extract by extension; a single non-archive file is left as-is.
                if downloaded_file_path.lower().endswith('zip'):
                    with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
                        zf.extractall(path=self.root)
                elif downloaded_file_path.lower().endswith('tar.gz'):
                    with tarfile.open(downloaded_file_path, 'r') as tf:
                        tf.extractall(path=self.root)
                elif len(self._checksums) > 1:
                    # Multiple files are expected but the download is not a
                    # recognized archive: the remaining checksums can never be
                    # satisfied, so fail loudly.
                    err = 'Failed retrieving {clsname}.'.format(
                        clsname=self.__class__.__name__)
                    err += (' Expecting multiple files, '
                            'but could not detect archive format.')
                    raise RuntimeError(err)
Пример #6
0
class FashionMNIST(object):
    """Download locations and SHA-1 checksums for the Fashion-MNIST files."""

    namespace = "gluon/dataset/fashion-mnist"

    # Each attribute is a (download URL, SHA-1 checksum) pair for one
    # IDX-format file of the dataset.
    train_data = (utils._get_repo_file_url(namespace, "train-images-idx3-ubyte.gz"),
                  "0cf37b0d40ed5169c6b3aba31069a9770ac9043d")
    train_label = (utils._get_repo_file_url(namespace, "train-labels-idx1-ubyte.gz"),
                   "236021d52f1e40852b06a4c3008d8de8aef1e40b")
    test_data = (utils._get_repo_file_url(namespace, "t10k-images-idx3-ubyte.gz"),
                 "626ed6a7c06dd17c0eec72fa3be1740f146a2863")
    test_label = (utils._get_repo_file_url(namespace, "t10k-labels-idx1-ubyte.gz"),
                  "17f9ab60e7257a1620f4ad76bbbaf857c3920701")
Пример #7
0
class MNIST(object):
    """Download locations and SHA-1 checksums for the MNIST files."""

    namespace = "gluon/dataset/mnist"

    # (download URL, SHA-1 checksum) pairs for the four IDX files.
    train_data = (utils._get_repo_file_url(namespace, "train-images-idx3-ubyte.gz"),
                  "6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d")
    train_label = (utils._get_repo_file_url(namespace, "train-labels-idx1-ubyte.gz"),
                   "2a80914081dc54586dbdf242f9805a6b8d2a15fc")
    test_data = (utils._get_repo_file_url(namespace, "t10k-images-idx3-ubyte.gz"),
                 "c3a25af1f52dad7f726cce8cacb138654b760d48")
    test_label = (utils._get_repo_file_url(namespace, "t10k-labels-idx1-ubyte.gz"),
                  "763e7fa3757d93b0cdec073cef058b2004252c17")
Пример #8
0
 def _get_data(self):
     """Download and extract the dataset zip when the local copy is invalid."""
     name_template, sha1_hash = self._download_info
     archive_name = name_template.format(sha1_hash[:8])
     archive_path = os.path.join(self._root, archive_name)
     url = _get_repo_file_url('gluon/dataset', archive_name)
     have_valid = os.path.exists(archive_path) and check_sha1(archive_path, sha1_hash)
     if not have_valid:
         # ``path`` here is the full target file name, not a directory.
         download(url, path=archive_path, sha1_hash=sha1_hash)
         with zipfile.ZipFile(archive_path, 'r') as zf:
             zf.extractall(self._root)
Пример #9
0
    def _get_vocab(self):
        """Return the local vocab file path, downloading its archive if needed."""
        archive_name, archive_sha1 = self._archive_vocab
        vocab_name, vocab_sha1 = self._vocab_file
        vocab_path = os.path.join(self._root, vocab_name)
        # Only fetch when the vocab file is absent or fails its checksum.
        if not os.path.exists(vocab_path) or not check_sha1(vocab_path, vocab_sha1):
            url = _get_repo_file_url('gluon/dataset/vocab', archive_name)
            archive_path = download(url, path=self._root, sha1_hash=archive_sha1)
            with zipfile.ZipFile(archive_path, 'r') as zf:
                zf.extractall(path=self._root)
        return vocab_path
Пример #10
0
    def _get_data(archive_file, data_file, segment, root, namespace):
        """Ensure the segment's data file exists under ``root``; return its path."""
        archive_name, archive_sha1 = archive_file
        file_name, file_sha1 = data_file[segment]
        target = os.path.join(root, file_name)
        missing = not os.path.exists(target)
        if missing or not check_sha1(target, file_sha1):
            # Pull the whole archive and unpack it; the target file is a member.
            url = _get_repo_file_url(namespace, archive_name)
            archive_path = download(url, path=root, sha1_hash=archive_sha1)
            with zipfile.ZipFile(archive_path, 'r') as zf:
                zf.extractall(root)
        return target
Пример #11
0
 def _get_data(self, segment, zip_hash, data_hash, filename):
     """Download and flatten-extract the segment zip if ``filename`` is stale."""
     archive_name = '%s-%s.zip' % (segment, data_hash[:8])
     up_to_date = os.path.exists(filename) and check_sha1(filename, data_hash)
     if not up_to_date:
         download(_get_repo_file_url(self._repo_dir(), archive_name),
                  path=self._root, sha1_hash=zip_hash)
         archive_path = os.path.join(self._root, archive_name)
         with zipfile.ZipFile(archive_path, 'r') as zf:
             for info in zf.infolist():
                 # Directory entries end with '/'; skip them and flatten
                 # everything else directly into self._root.
                 if info.filename.endswith('/'):
                     continue
                 info.filename = os.path.basename(info.filename)
                 zf.extract(info, self._root)
Пример #12
0
 def _get_data(self):
     """Return local paths of all data files, downloading any stale ones."""
     archive_name, archive_sha1 = self._get_data_archive_hash()
     # Resolve the download URL once; it does not change per file.
     if hasattr(self, 'namespace'):
         url = _get_repo_file_url(self.namespace, archive_name)
     else:
         url = self.base_url + archive_name
     paths = []
     for file_name, file_sha1 in self._get_data_file_hash():
         path = os.path.join(self._root, file_name)
         if not (os.path.exists(path) and check_sha1(path, file_sha1)):
             download(url, path=self._root, sha1_hash=archive_sha1)
             self._extract_archive()
         paths.append(path)
     return paths
Пример #13
0
    def _get_data(self):
        """Load data from the file. Do nothing if data was loaded before.
        """
        (archive_name, archive_sha1), (data_name, data_sha1) = \
            self._data_file()[self._segment]
        data_path = os.path.join(self._root, data_name)

        up_to_date = os.path.exists(data_path) and check_sha1(data_path, data_sha1)
        if not up_to_date:
            archive_path = download(
                _get_repo_file_url(self._repo_dir(), archive_name),
                path=self._root, sha1_hash=archive_sha1)
            # Flatten the archive: copy each regular member straight into root.
            with zipfile.ZipFile(archive_path, 'r') as zf:
                for member in zf.namelist():
                    base = os.path.basename(member)
                    if not base:
                        continue  # directory entry
                    with zf.open(member) as src, \
                         open(os.path.join(self._root, base), 'wb') as dst:
                        shutil.copyfileobj(src, dst)
Пример #14
0
    def _get_data(self):
        """Return the path of the segment's data file, fetching it if needed."""
        archive_name, archive_sha1 = self._archive_file
        data_name, data_sha1 = self._data_file[self._segment]
        data_path = os.path.join(self._root, data_name)
        if not os.path.exists(data_path) or not check_sha1(data_path, data_sha1):
            archive_path = download(
                _get_repo_file_url(self._namespace, archive_name),
                path=self._root, sha1_hash=archive_sha1)
            # Extract every regular member directly into root (paths flattened).
            with zipfile.ZipFile(archive_path, 'r') as zf:
                for member in zf.namelist():
                    base = os.path.basename(member)
                    if base:
                        with zf.open(member) as src, \
                             open(os.path.join(self._root, base), 'wb') as dst:
                            shutil.copyfileobj(src, dst)
        return data_path
Пример #15
0
    def _get_data(self):
        """Load data from the file. Does nothing if data was loaded before
        """
        archive_name, _, archive_sha1 = self._data_file[self._segment]
        archive_path = os.path.join(self._root, archive_name)

        if not (os.path.exists(archive_path) and check_sha1(archive_path, archive_sha1)):
            fetched = download(
                _get_repo_file_url('gluon/dataset/squad', archive_name),
                path=self._root, sha1_hash=archive_sha1)
            # Copy each regular zip member into root, discarding directory paths.
            with zipfile.ZipFile(fetched, 'r') as zf:
                for member in zf.namelist():
                    base = os.path.basename(member)
                    if base:
                        with zf.open(member) as src, \
                             open(os.path.join(self._root, base), 'wb') as dst:
                            shutil.copyfileobj(src, dst)
Пример #16
0
    def _get_data(self):
        """Ensure the segment's data file exists locally and return its path."""
        archive_name, archive_sha1 = self._archive_file
        file_name, file_sha1 = self._data_file[self._segment]
        root = self._root
        target = os.path.join(root, file_name)
        have_valid_copy = os.path.exists(target) and check_sha1(target, file_sha1)
        if not have_valid_copy:
            fetched = download(_get_repo_file_url(self._namespace, archive_name),
                               path=root, sha1_hash=archive_sha1)
            with zipfile.ZipFile(fetched, 'r') as zf:
                # Drop directory structure: every regular member lands in root.
                for member in zf.namelist():
                    leaf = os.path.basename(member)
                    if not leaf:
                        continue
                    with zf.open(member) as src:
                        with open(os.path.join(root, leaf), 'wb') as dst:
                            shutil.copyfileobj(src, dst)
        return target
Пример #17
0
    def _get_data(self):
        """Load data from the file. Does nothing if data was loaded before.

        Downloads the segment's archive into a temporary directory and then
        atomically moves each extracted member into ``self._root``.
        """
        (data_archive_name, archive_hash), (data_name, data_hash) \
            = self._data_file[self._version][self._segment]
        data_path = os.path.join(self._root, data_name)

        if not os.path.exists(data_path) or not check_sha1(
                data_path, data_hash):
            # Download into a temp dir so a partial download never pollutes root.
            with tempfile.TemporaryDirectory(dir=self._root) as temp_dir:
                file_path = download(_get_repo_file_url(
                    'gluon/dataset/squad', data_archive_name),
                                     path=temp_dir,
                                     sha1_hash=archive_hash)
                with zipfile.ZipFile(file_path, 'r') as zf:
                    for member in zf.namelist():
                        filename = os.path.basename(member)
                        if not filename:
                            continue  # directory entry
                        dest = os.path.join(self._root, filename)
                        # Write to a unique temp name, then atomically replace
                        # the destination. BUG FIX: replace_file previously ran
                        # while the temp file was still open, which fails on
                        # Windows (cannot rename an open file).
                        temp_dst = dest + str(uuid.uuid4())
                        with zf.open(member) as source:
                            with open(temp_dst, 'wb') as target:
                                shutil.copyfileobj(source, target)
                        replace_file(temp_dst, dest)
Пример #18
0
    def _get_file_url(cls, source):
        """Build the repo URL for *source*'s embedding file of this class."""
        namespace = 'gluon/embeddings/{}'.format(cls.__name__.lower())
        file_name = cls.source_file_hash[source][0]
        return _get_repo_file_url(namespace, file_name)
Пример #19
0
 def _get_file_url(cls_name, source_file_hash, source):
     """Return the repo URL of the file registered for *source*."""
     file_name = source_file_hash[source][0]
     return _get_repo_file_url('gluon/embeddings/{}'.format(cls_name), file_name)
Пример #20
0
    def _get_data(self):
        """Build a few-shot CIFAR-100 subset for the selected classes.

        Downloads/extracts the archive when files are missing or corrupt,
        picks ``self._c_way`` classes (chosen once with seed 0 and cached in
        ``self._fix_class``), samples ``self._k_shot`` examples per class for
        training (all examples for test), relabels them sequentially starting
        at ``self._base_class``, and stores the result in ``self._data`` /
        ``self._label``.
        """

        # Re-download the archive when any batch file is missing or fails SHA-1.
        if any(not os.path.exists(path) or not check_sha1(path, sha1)
               for path, sha1 in ((os.path.join(self._root, name), sha1)
                                  for name, sha1 in self._train_data +
                                  self._test_data)):
            namespace = 'gluon/dataset/' + self._namespace
            filename = download(_get_repo_file_url(namespace,
                                                   self._archive_file[0]),
                                path=self._root,
                                sha1_hash=self._archive_file[1])

            with tarfile.open(filename) as tar:
                tar.extractall(self._root)

        if self._train:
            data_files = self._train_data
        else:
            data_files = self._test_data
        # Read every batch file and stack into single data/label arrays.
        data, label = zip(*(self._read_batch(os.path.join(self._root, name))
                            for name, _ in data_files))

        data = np.concatenate(data)
        label = np.concatenate(label)
        # First call only: deterministically pick the class subset (seed 0)
        # so repeated instantiations agree on the episode's classes.
        if not self._fix_class:
            np.random.seed(0)
            classes = np.random.choice(np.max(label) + 1,
                                       size=self._c_way,
                                       replace=False)
            self._fix_class = list(classes)
        if self._logger:
            self._logger.info(
                'select CIFAR100 classes : {} , fine label = {}, train = {}'.
                format(self._fix_class, self._fine_label, self._train))

        if self._train:
            # Sample k_shot examples per selected class; new labels are the
            # class's index offset by _base_class.
            select_index = list()
            new_label = list()
            for i, l in enumerate(self._fix_class):
                ind = list(np.where(l == label)[0])
                # NOTE(review): the seed is reset to 1 on every iteration, so
                # each per-class draw uses the same RNG state — confirm this
                # repeated reseed is intended rather than a single seed upfront.
                np.random.seed(1)
                random_ind = np.random.choice(ind, self._k_shot, replace=False)
                select_index.extend(random_ind)
                new_label.extend([i + self._base_class] * len(random_ind))
            data = data[select_index]
            label = np.array(new_label)
        else:
            # Test split: keep every example of the selected classes.
            select_index = list()
            new_label = list()
            for i, l in enumerate(self._fix_class):
                ind = list(np.where(l == label)[0])
                select_index.extend(ind)
                new_label.extend([i + self._base_class] * len(ind))
            data = data[select_index]
            label = np.array(new_label)

        self._data = nd.array(data, dtype=data.dtype)
        self._label = label
        if self._logger:
            self._logger.info('the number of cifar100 new class samples : %d' %
                              (label.shape[0]))
Пример #21
0
def _get_file_url(cls_name, file_name):
    """Build the embeddings-repo URL for *file_name* under *cls_name*'s namespace."""
    return _get_repo_file_url('gluon/embeddings/{}'.format(cls_name), file_name)
Пример #22
0
 def _get_file_url(cls_name, source_file_hash, source):
     """Resolve *source* to its registered file name and build the repo URL."""
     file_name = source_file_hash[source][0]
     ns = 'gluon/embeddings/{}'.format(cls_name)
     return _get_repo_file_url(ns, file_name)