Example #1
    def _get_data(self):
        data_file_name, data_hash = self._data_file()[self._segment]
        root = self._root
        path = os.path.join(root, data_file_name)
        if not os.path.exists(path) or not check_sha1(path, data_hash):
            download(_get_repo_file_url(self._repo_dir(), data_file_name),
                     path=root, sha1_hash=data_hash)
Example #2
File: tf.py Project: LANHUIYING/tvm
def get_workload_ptb():
    """ Import ptb workload from frozen protobuf

    Parameters
    ----------
        Nothing.

    Returns
    -------
    word_to_id : dict
        English word to integer id mapping

    id_to_word : dict
        Integer id to English word mapping

    graph_def: graphdef
        graph_def is the tensorflow workload for ptb.
    """
    sample_repo = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/'
    sample_data_file = 'simple-examples.tgz'
    sample_url = sample_repo+sample_data_file
    ptb_model_file = 'RNN/ptb/ptb_model_with_lstmblockcell.pb'

    import tarfile
    from tvm.contrib.download import download
    DATA_DIR = './ptb_data/'
    if not os.path.exists(DATA_DIR):
        os.mkdir(DATA_DIR)
    download(sample_url, DATA_DIR+sample_data_file)
    t = tarfile.open(DATA_DIR+sample_data_file, 'r')
    t.extractall(DATA_DIR)

    word_to_id, id_to_word = _create_ptb_vocabulary(DATA_DIR)
    return word_to_id, id_to_word, get_workload(ptb_model_file)
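A hypothetical call of get_workload_ptb above (it assumes TensorFlow, the TVM test utilities, and the _create_ptb_vocabulary helper referenced in the snippet are all available):

word_to_id, id_to_word, graph_def = get_workload_ptb()
print('PTB vocabulary size:', len(word_to_id))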
Example #3
def _download_pikachu(data_dir):
    root_url = ('https://apache-mxnet.s3-accelerate.amazonaws.com/'
                'gluon/dataset/pikachu/')
    dataset = {'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8',
               'train.idx': 'dcf7318b2602c06428b9988470c731621716c393',
               'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'}
    for k, v in dataset.items():
        gutils.download(root_url + k, os.path.join(data_dir, k), sha1_hash=v)
Example #4
def _get_data(shape):
    hash_test_img = "355e15800642286e7fe607d87c38aeeab085b0cc"
    hash_inception_v3 = "91807dfdbd336eb3b265dd62c2408882462752b9"
    utils.download("http://data.mxnet.io/data/test_images_%d_%d.npy" % (shape),
                   path="data/test_images_%d_%d.npy" % (shape),
                   sha1_hash=hash_test_img)
    utils.download("http://data.mxnet.io/data/inception-v3-dump.npz",
                   path='data/inception-v3-dump.npz',
                   sha1_hash=hash_inception_v3)
Example #5
    def _get_data(self):
        archive_file_name, archive_hash = self._get_data_archive_hash()
        paths = []
        for data_file_name, data_hash in self._get_data_file_hash():
            root = self._root
            path = os.path.join(root, data_file_name)
            if not os.path.exists(path) or not check_sha1(path, data_hash):
                download(self.base_url + archive_file_name,
                         path=root, sha1_hash=archive_hash)
                self._extract_archive()
            paths.append(path)
        return paths
Example #6
    def __init__(self, segmenter_root=os.path.join(_get_home_dir(), 'stanford-segmenter'),
                 slf4j_root=os.path.join(_get_home_dir(), 'slf4j'),
                 java_class='edu.stanford.nlp.ie.crf.CRFClassifier'):
        is_java_exist = os.system('java -version')
        assert is_java_exist == 0, 'Java is not installed. You must install Java 8.0 ' \
                                   'in order to use the NLTKStanfordSegmenter'
        try:
            from nltk.tokenize import StanfordSegmenter
        except ImportError:
            raise ImportError('NLTK or relevant packages are not installed. You must install NLTK '
                              'in order to use the NLTKStanfordSegmenter. You can refer to the '
                              'official installation guide in https://www.nltk.org/install.html.')
        path_to_jar = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27',
                                   'stanford-segmenter-3.9.1.jar')
        path_to_model = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27',
                                     'data', 'pku.gz')
        path_to_dict = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27',
                                    'data', 'dict-chris6.ser.gz')
        path_to_sihan_corpora_dict = os.path.join(segmenter_root,
                                                  'stanford-segmenter-2018-02-27', 'data')
        segmenter_url = 'https://nlp.stanford.edu/software/stanford-segmenter-2018-02-27.zip'
        segmenter_sha1 = 'aa27a6433704b7b4c6a14be1c126cb4b14b8f57b'
        stanford_segmenter = os.path.join(segmenter_root, 'stanford-segmenter-2018-02-27.zip')
        if not os.path.exists(path_to_jar) or \
                not os.path.exists(path_to_model) or \
                not os.path.exists(path_to_dict) or \
                not os.path.exists(path_to_sihan_corpora_dict) or \
                not check_sha1(filename=stanford_segmenter, sha1_hash=segmenter_sha1):
            # automatically download the files from the website and place them to stanford_root
            if not os.path.exists(segmenter_root):
                os.mkdir(segmenter_root)
            download(url=segmenter_url, path=segmenter_root, sha1_hash=segmenter_sha1)
            _extract_archive(file=stanford_segmenter, target_dir=segmenter_root)

        path_to_slf4j = os.path.join(slf4j_root, 'slf4j-1.7.25', 'slf4j-api-1.7.25.jar')
        slf4j_url = 'https://www.slf4j.org/dist/slf4j-1.7.25.zip'
        slf4j_sha1 = '89ea41ad6ebe1b190139421bb7c8d981e9df1625'
        slf4j = os.path.join(slf4j_root, 'slf4j-1.7.25.zip')
        if not os.path.exists(path_to_slf4j) or \
                not check_sha1(filename=slf4j, sha1_hash=slf4j_sha1):
            # automatically download the files from the website and place them to slf4j_root
            if not os.path.exists(slf4j_root):
                os.mkdir(slf4j_root)
            download(url=slf4j_url, path=slf4j_root, sha1_hash=slf4j_sha1)
            _extract_archive(file=slf4j, target_dir=slf4j_root)
        self._tokenizer = StanfordSegmenter(java_class=java_class, path_to_jar=path_to_jar,
                                            path_to_slf4j=path_to_slf4j, path_to_dict=path_to_dict,
                                            path_to_sihan_corpora_dict=path_to_sihan_corpora_dict,
                                            path_to_model=path_to_model)
Example #7
File: utils.py Project: tsintian/d2l-zh
def download_imdb(data_dir='../data'):
    """Download the IMDB data set for sentiment analysis."""
    url = ('http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz')
    sha1 = '01ada507287d82875905620988597833ad4e0903'
    fname = gutils.download(url, data_dir, sha1_hash=sha1)
    with tarfile.open(fname, 'r') as f:
        f.extractall(data_dir)
Example #8
    def _download_data(self):
        _, archive_hash = self._archive_file
        for name, checksum in self._checksums.items():
            name = name.split('/')
            path = os.path.join(self.root, *name)
            if not os.path.exists(path) or not check_sha1(path, checksum):
                if self._namespace is not None:
                    url = _get_repo_file_url(self._namespace,
                                             self._archive_file[0])
                else:
                    url = self._url
                downloaded_file_path = download(url, path=self.root,
                                                sha1_hash=archive_hash,
                                                verify_ssl=self._verify_ssl)

                if downloaded_file_path.lower().endswith('zip'):
                    with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
                        zf.extractall(path=self.root)
                elif downloaded_file_path.lower().endswith('tar.gz'):
                    with tarfile.open(downloaded_file_path, 'r') as tf:
                        tf.extractall(path=self.root)
                elif len(self._checksums) > 1:
                    err = 'Failed retrieving {clsname}.'.format(
                        clsname=self.__class__.__name__)
                    err += (' Expecting multiple files, '
                            'but could not detect archive format.')
                    raise RuntimeError(err)
Example #9
def download_imdb(data_dir='../data'):
    """Download the IMDB data set for sentiment analysis."""
    url = ('http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz')
    sha1 = '01ada507287d82875905620988597833ad4e0903'
    fname = gutils.download(url, data_dir, sha1_hash=sha1)
    with tarfile.open(fname, 'r') as f:
        f.extractall(data_dir)
Example #10
    def _get_data(self):
        if any(not os.path.exists(path) or not check_sha1(path, sha1)
               for path, sha1 in ((os.path.join(self._root, name), sha1)
                                  for name, sha1 in self._train_data +
                                  self._test_data)):
            namespace = 'gluon/dataset/' + self._namespace
            filename = download(_get_repo_file_url(namespace,
                                                   self._archive_file[0]),
                                path=self._root,
                                sha1_hash=self._archive_file[1])

            with tarfile.open(filename) as tar:
                tar.extractall(self._root)

        if self._train:
            data_files = self._train_data
        else:
            data_files = self._test_data
        data, label = zip(*(self._read_batch(os.path.join(self._root, name))
                            for name, _ in data_files))
        data = np.concatenate(data)
        label = np.concatenate(label)
        if self._train:
            npr.seed(0)
            rand_inds = npr.permutation(50000)
            data = data[rand_inds]
            label = label[rand_inds]
            data = data[self._split_id * 10000:(self._split_id + 1) * 10000]
            label = label[self._split_id * 10000:(self._split_id + 1) * 10000]
        self._data = nd.array(data, dtype=data.dtype)
        self._label = label
Example #11
def _load_pretrained_vocab(name, root=os.path.join('~', '.mxnet', 'models')):
    """Load the accompanying vocabulary object for pre-trained model.

    Parameters
    ----------
    name : str
        Name of the vocabulary, usually the name of the dataset.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Returns
    -------
    Vocab
        Loaded vocabulary object for the pre-trained model.
    """
    file_name = '{name}-{short_hash}'.format(name=name,
                                             short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name+'.vocab')
    sha1_hash = _vocab_sha1[name]
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return _load_vocab_file(file_path)
        else:
            print('Detected mismatch in the content of model vocab file. Downloading again.')
    else:
        print('Vocab file is not found. Downloading.')

    if not os.path.exists(root):
        os.makedirs(root)

    zip_file_path = os.path.join(root, file_name+'.zip')
    repo_url = _get_repo_url()
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path,
             overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
    os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return _load_vocab_file(file_path)
    else:
        raise ValueError('Downloaded file has different hash. Please try again.')
Example #12
def test_sentencepiece_tokenizer():
    url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/vocab/{}'
    filename = 'test-0690baed.bpe'
    download(url_format.format(filename), path=os.path.join('tests', 'data', filename))
    tokenizer = t.SentencepieceTokenizer(os.path.join('tests', 'data', filename))
    detokenizer = t.SentencepieceDetokenizer(os.path.join('tests', 'data', filename))
    text = "Introducing Gluon: An Easy-to-Use Programming Interface for Flexible Deep Learning."
    try:
        ret = tokenizer(text)
        detext = detokenizer(ret)
    except ImportError:
        warnings.warn("Sentencepiece not installed, skip test_sentencepiece_tokenizer().")
        return
    assert isinstance(ret, list)
    assert all(t in tokenizer.tokens for t in ret)
    assert len(ret) > 0
    assert text == detext
Example #13
File: P2PGAN.py Project: wshaow/GAN
def download_data(dataset):
    if not os.path.exists(dataset):  # only download if the dataset does not exist yet
        url = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/%s.tar.gz' % (dataset)
        os.mkdir(dataset)  # create the dataset directory
        data_file = utils.download(url)  # download the dataset from the url; data_file is the archive's filename
        with tarfile.open(data_file) as tar:  # extract the archive into the current directory
            tar.extractall(path='.')
        os.remove(data_file)  # delete the archive once extraction is done
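A hypothetical call of download_data above; 'facades' is one of the dataset names published under the pix2pix datasets URL, and the archive is unpacked into the current directory:

download_data('facades')  # creates ./facades and removes the downloaded archive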
Example #14
def _load_pretrained_vocab(name, root=os.path.join('~', '.mxnet', 'models')):
    """Load the accompanying vocabulary object for pre-trained model.

    Parameters
    ----------
    name : str
        Name of the vocabulary, usually the name of the dataset.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Returns
    -------
    Vocab
        Loaded vocabulary object for the pre-trained model.
    """
    file_name = '{name}-{short_hash}'.format(name=name,
                                             short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name+'.vocab')
    sha1_hash = _vocab_sha1[name]
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return _load_vocab_file(file_path)
        else:
            print('Detected mismatch in the content of model vocab file. Downloading again.')
    else:
        print('Vocab file is not found. Downloading.')

    if not os.path.exists(root):
        os.makedirs(root)

    zip_file_path = os.path.join(root, file_name+'.zip')
    repo_url = _get_repo_url()
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path,
             overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
    os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return _load_vocab_file(file_path)
    else:
        raise ValueError('Downloaded file has different hash. Please try again.')
Example #15
    def _get_file_path(cls, source_file_hash, embedding_root, source):
        cls_name = cls.__name__.lower()
        embedding_root = os.path.expanduser(embedding_root)
        url = cls._get_file_url(cls_name, source_file_hash, source)

        embedding_dir = os.path.join(embedding_root, cls_name)

        pretrained_file_name, expected_file_hash = source_file_hash[source]
        pretrained_file_path = os.path.join(embedding_dir, pretrained_file_name)

        if not os.path.exists(pretrained_file_path) \
           or not check_sha1(pretrained_file_path, expected_file_hash):
            print('Embedding file {} is not found. Downloading from Gluon Repository. '
                  'This may take some time.'.format(pretrained_file_name))
            download(url, pretrained_file_path, sha1_hash=expected_file_hash)

        return pretrained_file_path
Example #16
    def _get_file_path(cls, source_file_hash, embedding_root, source):
        cls_name = cls.__name__.lower()
        embedding_root = os.path.expanduser(embedding_root)
        url = cls._get_file_url(cls_name, source_file_hash, source)

        embedding_dir = os.path.join(embedding_root, cls_name)

        pretrained_file_name, expected_file_hash = source_file_hash[source]
        pretrained_file_path = os.path.join(embedding_dir, pretrained_file_name)

        if not os.path.exists(pretrained_file_path) \
           or not check_sha1(pretrained_file_path, expected_file_hash):
            print('Embedding file {} is not found. Downloading from Gluon Repository. '
                  'This may take some time.'.format(pretrained_file_name))
            download(url, pretrained_file_path, sha1_hash=expected_file_hash)

        return pretrained_file_path
Example #17
def download_voc_pascal(data_dir='../data'):
    voc_dir = os.path.join(data_dir, 'VOCdevkit/VOC2012')
    url = ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar')
    sha1 = '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'
    fname = gutils.download(url, data_dir, sha1_hash=sha1)
    with tarfile.open(fname, 'r') as f:
        f.extractall(data_dir)
    return voc_dir
Example #18
def download_img_labels():
    """ Download an image and imagenet1k class labels for test"""
    img_name = 'cat.png'
    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
                      '4d0b62f3d01426887599d4f7ede23ee5/raw/',
                      '596b27d23537e5a1b5751d2b0481ef172f58b539/',
                      'imagenet1000_clsid_to_human.txt'])
    synset_name = 'synset.txt'
    download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
    download(synset_url, synset_name)

    with open(synset_name) as fin:
        synset = eval(fin.read())

    with open("synset.csv", "w") as fout:
        w = csv.writer(fout)
        w.writerows(synset.items())
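The snippet above turns the downloaded label file into a dict with eval. As a more defensive sketch (an alternative, not the author's code), ast.literal_eval parses the same dict literal without executing arbitrary code:

import ast
import csv

with open('synset.txt') as fin:
    # the file contains a Python dict literal mapping class id -> human-readable label
    synset = ast.literal_eval(fin.read())

with open('synset.csv', 'w') as fout:
    csv.writer(fout).writerows(synset.items())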
Example #19
File: nmt.py Project: zhould1990/d2l-en
def load_data_nmt(batch_size, max_len, num_examples=1000):
    """Download an NMT dataset, return its vocabulary and data iterator."""

    # Download and preprocess
    def preprocess_raw(text):
        text = text.replace('\u202f', ' ').replace('\xa0', ' ')
        out = ''
        for i, char in enumerate(text.lower()):
            if char in (',', '!', '.') and text[i - 1] != ' ':
                out += ' '
            out += char
        return out

    fname = gutils.download('http://www.manythings.org/anki/fra-eng.zip')
    with zipfile.ZipFile(fname, 'r') as f:
        raw_text = f.read('fra.txt').decode("utf-8")
    text = preprocess_raw(raw_text)

    # Tokenize
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if i >= num_examples:
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))

    # Build vocab
    def build_vocab(tokens):
        tokens = [token for line in tokens for token in line]
        return Vocab(tokens, min_freq=3, use_special_tokens=True)

    src_vocab, tgt_vocab = build_vocab(source), build_vocab(target)

    # Convert to index arrays
    def pad(line, max_len, padding_token):
        if len(line) > max_len:
            return line[:max_len]
        return line + [padding_token] * (max_len - len(line))

    def build_array(lines, vocab, max_len, is_source):
        lines = [vocab[line] for line in lines]
        if not is_source:
            lines = [[vocab.bos] + line + [vocab.eos] for line in lines]
        array = nd.array([pad(line, max_len, vocab.pad) for line in lines])
        valid_len = (array != vocab.pad).sum(axis=1)
        return array, valid_len

    src_array, src_valid_len = build_array(source, src_vocab, max_len, True)
    tgt_array, tgt_valid_len = build_array(target, tgt_vocab, max_len, False)

    # Construct data iterator
    train_set = gdata.ArrayDataset(src_array, src_valid_len, tgt_array,
                                   tgt_valid_len)
    train_iter = gdata.DataLoader(train_set, batch_size, shuffle=True)

    return src_vocab, tgt_vocab, train_iter
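A hypothetical call of load_data_nmt above, assuming the Vocab, nd, gdata and gutils names used in the snippet are in scope:

src_vocab, tgt_vocab, train_iter = load_data_nmt(batch_size=64, max_len=10)
for X, X_valid_len, Y, Y_valid_len in train_iter:
    print(X.shape, Y.shape)  # (64, 10) (64, 10)
    break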
Example #20
File: voc.py Project: zhould1990/d2l-en
def download_voc_pascal(data_dir='../data'):
    """Download the Pascal VOC2012 Dataset."""
    voc_dir = os.path.join(data_dir, 'VOCdevkit/VOC2012')
    url = 'http://data.mxnet.io/data/VOCtrainval_11-May-2012.tar'
    sha1 = '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'
    fname = gutils.download(url, data_dir, sha1_hash=sha1)
    with tarfile.open(fname, 'r') as f:
        f.extractall(data_dir)
    return voc_dir
Example #21
def get_official_squad_eval_script(version='2.0', download_dir=None):
    url_info = {
        '2.0': [
            'evaluate-v2.0.py', 'https://worksheets.codalab.org/rest/bundles/'
            '0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/',
            '5a584f1952c88b4088be5b51f2046a2c337aa706'
        ]
    }
    if version not in url_info:
        raise ValueError('Version {} is not supported'.format(version))
    if download_dir is None:
        download_dir = os.path.realpath(
            os.path.dirname(os.path.realpath(__file__)))
    download_path = os.path.join(download_dir, url_info[version][0])
    download(url_info[version][1],
             download_path,
             sha1_hash=url_info[version][2])
    return download_path
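A hypothetical call of get_official_squad_eval_script above; with no download_dir it saves the verified evaluate-v2.0.py next to the calling module and returns its path:

eval_script = get_official_squad_eval_script(version='2.0')
print(eval_script)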
Example #22
def test_sentencepiece_tokenizer_subword_regularization():
    url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/vocab/{}'
    filename = 'test-31c8ed7b.uni'
    download(url_format.format(filename), path=os.path.join('tests', 'data', filename))
    tokenizer = t.SentencepieceTokenizer(os.path.join('tests', 'data', filename),
                                         -1, 0.1)
    detokenizer = t.SentencepieceDetokenizer(os.path.join('tests', 'data', filename))
    text = "Introducing Gluon: An Easy-to-Use Programming Interface for Flexible Deep Learning."
    try:
        reg_ret = [tokenizer(text) for _ in range(10)]
        detext = detokenizer(reg_ret[0])
    except ImportError:
        warnings.warn("Sentencepiece not installed, skip test_sentencepiece_tokenizer().")
        return
    assert text == detext
    assert any(reg_ret[i] != reg_ret[0] for i in range(len(reg_ret)))
    assert all(t in tokenizer.tokens for ret in reg_ret for t in ret)
    assert all(detokenizer(reg_ret[i]) == detext for i in range(len(reg_ret)))
Example #23
def download_data(dataset):
    if not os.path.exists(dataset):
        url = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/%s.tar.gz' % (
            dataset)
        os.mkdir(dataset)
        data_file = utils.download(url)
        with tarfile.open(data_file) as tar:
            tar.extractall(path='.')
        os.remove(data_file)
Example #24
def _get_xlnet_tokenizer(dataset_name, root):
    assert dataset_name.lower() == '126gb'
    root = os.path.expanduser(root)
    file_path = os.path.join(root, 'xlnet_126gb-871f0b3c.spiece')
    sha1_hash = '871f0b3c13b92fc5aea8fba054a214c420e302fd'
    if os.path.exists(file_path):
        if not check_sha1(file_path, sha1_hash):
            print('Detected mismatch in the content of model tokenizer. Downloading again.')
    else:
        print('Tokenizer file is not found. Downloading.')

    if not os.path.exists(root):
        try:
            os.makedirs(root)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(root):
                pass
            else:
                raise e

    repo_url = _get_repo_url()
    prefix = str(time.time())
    zip_file_path = os.path.join(root, prefix + 'xlnet_126gb-871f0b3c.zip')
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name='xlnet_126gb-871f0b3c'),
             path=zip_file_path, overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        if not os.path.exists(file_path):
            zf.extractall(root)
    try:
        os.remove(zip_file_path)
    except OSError as e:
        # file has already been removed.
        if e.errno == 2:
            pass
        else:
            raise e

    if not check_sha1(file_path, sha1_hash):
        raise ValueError('Downloaded file has different hash. Please try again.')

    tokenizer = XLNetTokenizer(file_path)
    return tokenizer
Example #25
def download_data(root_folder, dataset):
    data_folder = os.path.join(root_folder, dataset)
    if os.path.exists(data_folder):
        return
    url = f"https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/{dataset}.tar.gz"
    os.makedirs(data_folder, exist_ok=True)
    data_file = utils.download(url)
    with tarfile.open(data_file) as tar:
        tar.extractall(path=data_folder)
    os.remove(data_file)
Example #26
def _download_voc_pascal(data_dir='../data'):
    voc_dir = data_dir + '/VOCdevkit/VOC2012'
    url = ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012'
           '/VOCtrainval_11-May-2012.tar')
    sha1 = '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'
    fname = gutils.download(url, data_dir, sha1_hash=sha1)
    if not os.path.exists(voc_dir + '/ImageSets/Segmentation/train.txt'):
        with tarfile.open(fname, 'r') as f:
            f.extractall(data_dir)
    return voc_dir
Example #27
File: voc.py Project: tsintian/d2l-en
def download_voc_pascal(data_dir='../data'):
    """Download the Pascal VOC2012 Dataset."""
    voc_dir = os.path.join(data_dir, 'VOCdevkit/VOC2012')
    url = ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012'
           '/VOCtrainval_11-May-2012.tar')
    sha1 = '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'
    fname = gutils.download(url, data_dir, sha1_hash=sha1)
    with tarfile.open(fname, 'r') as f:
        f.extractall(data_dir)
    return voc_dir
Example #28
def get_workload_inception_v1():
    """ Import Inception V1 workload from frozen protobuf

    Parameters
    ----------
        Nothing.

    Returns
    -------
    (image_data, tvm_data, graph_def) : Tuple
        image_data is raw encoded image data for TF input.
        tvm_data is the decoded image data for TVM input.
        graph_def is the tensorflow workload for Inception V1.

    """

    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
    model_name = 'classify_image_graph_def-with_shapes.pb'
    model_url = os.path.join(repo_base, model_name)
    image_name = 'elephant-299.jpg'
    image_url = os.path.join(repo_base, image_name)

    from mxnet.gluon.utils import download
    download(model_url, model_name)
    download(image_url, image_name)

    if not tf.gfile.Exists(os.path.join("./", image_name)):
        tf.logging.fatal('File does not exist %s', image_name)
    image_data = tf.gfile.FastGFile(os.path.join("./", image_name),
                                    'rb').read()

    # TVM doesn't handle decode, hence decode it.
    from PIL import Image
    tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299))
    tvm_data = np.array(tvm_data)

    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name='')
        return (image_data, tvm_data, graph_def)
Example #29
File: nmt.py Project: tsintian/d2l-en
def load_data_nmt(batch_size, max_len, num_examples=1000):
    """Download an NMT dataset, return its vocabulary and data iterator."""
    # Download and preprocess
    def preprocess_raw(text):
        text = text.replace('\u202f', ' ').replace('\xa0', ' ')
        out = ''
        for i, char in enumerate(text.lower()):
            if char in (',', '!', '.') and text[i-1] != ' ':
                out += ' '
            out += char
        return out
    fname = gutils.download('http://www.manythings.org/anki/fra-eng.zip')
    with zipfile.ZipFile(fname, 'r') as f:
        raw_text = f.read('fra.txt').decode("utf-8")
    text = preprocess_raw(raw_text)

    # Tokenize
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if i >= num_examples:
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))

    # Build vocab
    def build_vocab(tokens):
        tokens = [token for line in tokens for token in line]
        return Vocab(tokens, min_freq=3, use_special_tokens=True)
    src_vocab, tgt_vocab = build_vocab(source), build_vocab(target)

    # Convert to index arrays
    def pad(line, max_len, padding_token):
        if len(line) > max_len:
            return line[:max_len]
        return line + [padding_token] * (max_len - len(line))

    def build_array(lines, vocab, max_len, is_source):
        lines = [vocab[line] for line in lines]
        if not is_source:
            lines = [[vocab.bos] + line + [vocab.eos] for line in lines]
        array = nd.array([pad(line, max_len, vocab.pad) for line in lines])
        valid_len = (array != vocab.pad).sum(axis=1)
        return array, valid_len

    src_array, src_valid_len = build_array(source, src_vocab, max_len, True)
    tgt_array, tgt_valid_len = build_array(target, tgt_vocab, max_len, False)

    # Construct data iterator
    train_set = gdata.ArrayDataset(src_array, src_valid_len, tgt_array, tgt_valid_len)
    train_iter = gdata.DataLoader(train_set, batch_size, shuffle=True)

    return src_vocab, tgt_vocab, train_iter
Example #30
    def _get_data(self):
        archive_file_name, archive_hash = self.archive_file
        data_file_name, data_hash = self.data_file[self._segment]
        root = self._root
        path = os.path.join(root, data_file_name)
        if not os.path.exists(path) or not check_sha1(path, data_hash):
            downloaded_file_path = download(self.url + archive_file_name,
                                            path=root, sha1_hash=archive_hash)

            with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
                zf.extractall(root)
        return path
Example #31
def test_sharded_data_loader_record_file():
    # test record file
    url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/{}'
    filename = 'val.rec'
    idx_filename = 'val.idx'
    download(url_format.format(filename), path=os.path.join('tests', 'data', filename))
    download(url_format.format(idx_filename), path=os.path.join('tests', 'data', idx_filename))
    rec_dataset = gluon.data.vision.ImageRecordDataset(os.path.join('tests', 'data', filename))

    num_workers = 2
    num_shards = 4
    X = np.random.uniform(size=(100, 20))
    Y = np.random.uniform(size=(100,))
    batch_sampler = FixedBucketSampler(lengths=[X.shape[1]] * X.shape[0],
                                       batch_size=2,
                                       num_buckets=1,
                                       shuffle=False,
                                       num_shards=num_shards)
    loader = ShardedDataLoader(rec_dataset, batch_sampler=batch_sampler, num_workers=num_workers)
    for i, seqs in enumerate(loader):
        assert len(seqs) == num_shards
Example #32
    def _get_data(self):
        if any(not os.path.exists(path) or not check_sha1(path, sha1) for path, sha1 in
               ((os.path.join(self._root, name), sha1) for _, name, sha1 in self._train_data + self._test_data)):
            for url, _, sha1 in self._train_data + self._test_data:
                download(url=url, path=self._root, sha1_hash=sha1)

        if self._mode == "train":
            data_files = self._train_data[0]
        else:
            data_files = self._test_data[0]

        import scipy.io as sio

        loaded_mat = sio.loadmat(os.path.join(self._root, data_files[1]))

        data = loaded_mat["X"]
        data = np.transpose(data, (3, 0, 1, 2))
        self._data = mx.nd.array(data, dtype=data.dtype)

        self._label = loaded_mat["y"].astype(np.int32).squeeze()
        np.place(self._label, self._label == 10, 0)
Example #33
    def _get_data(archive_file, data_file, segment, root, namespace):
        archive_file_name, archive_hash = archive_file
        data_file_name, data_hash = data_file[segment]
        path = os.path.join(root, data_file_name)
        if not os.path.exists(path) or not check_sha1(path, data_hash):
            downloaded_file_path = download(_get_repo_file_url(
                namespace, archive_file_name),
                                            path=root,
                                            sha1_hash=archive_hash)

            with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
                zf.extractall(root)
        return path
Example #34
def test_pretrained_gpt2(model_name, tmp_path):
    sentence = ' natural language processing tools such as gluonnlp and torchtext'
    model, vocab = get_model(model_name, dataset_name='openai_webtext')
    tokenizer = GPT2BPETokenizer()
    detokenizer = GPT2BPEDetokenizer()
    true_data_hash = {'gpt2_117m': '29526682508d03a7c54c598e889f77f7b4608df0',
                      'gpt2_345m': '6680fd2a3d7b737855536f480bc19d166f15a3ad'}
    file_name = '{model_name}_gt_logits-{short_hash}.npy'.format(
            model_name=model_name,
            short_hash=true_data_hash[model_name][:8])
    url_format = '{repo_url}gluon/dataset/test/{file_name}'
    repo_url = _get_repo_url()
    path = os.path.join(str(tmp_path), file_name)
    download(url_format.format(repo_url=repo_url, file_name=file_name),
             path=path,
             sha1_hash=true_data_hash[model_name])
    gt_logits = np.load(path)
    model.hybridize()
    indices = vocab[tokenizer(sentence)]
    nd_indices = mx.nd.expand_dims(mx.nd.array(indices), axis=0)
    logits, new_states = model(nd_indices, None)
    npt.assert_allclose(logits.asnumpy(), gt_logits, 1E-5, 1E-5)
Example #35
def build(target_dir):
    """ Compiles resnet18 with TVM"""

    # download the pretrained resnet18 trained on imagenet1k dataset for
    # image classification task
    block = get_model('resnet18_v1', pretrained=True)

    sym, params = nnvm.frontend.from_mxnet(block)
    # add the softmax layer for prediction
    net = nnvm.sym.softmax(sym)
    # compile the model
    with nnvm.compiler.build_config(opt_level=opt_level):
        graph, lib, params = nnvm.compiler.build(
            net, target, shape={"data": data_shape}, params=params)
    # save the model artifacts
    lib.save(os.path.join(target_dir, "deploy_lib.o"))
    cc.create_shared(os.path.join(target_dir, "deploy_lib.so"),
                     [os.path.join(target_dir, "deploy_lib.o")])

    with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo:
        fo.write(graph.json())
    with open(os.path.join(target_dir, "deploy_param.params"), "wb") as fo:
        fo.write(nnvm.compiler.save_param_dict(params))
    # download an image and imagenet1k class labels for test
    img_name = 'cat.png'
    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
                          '4d0b62f3d01426887599d4f7ede23ee5/raw/',
                          '596b27d23537e5a1b5751d2b0481ef172f58b539/',
                          'imagenet1000_clsid_to_human.txt'])
    synset_name = 'synset.txt'
    download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
    download(synset_url, synset_name)

    with open(synset_name) as fin:
        synset = eval(fin.read())

    with open("synset.csv", "w") as fout:
        w = csv.writer(fout)
        w.writerows(synset.items())
Example #36
    def _get_vocab(self):
        archive_file_name, archive_hash = self._archive_vocab
        vocab_file_name, vocab_hash = self._vocab_file
        namespace = 'gluon/dataset/vocab'
        root = self._root
        path = os.path.join(root, vocab_file_name)
        if not os.path.exists(path) or not check_sha1(path, vocab_hash):
            downloaded_path = download(_get_repo_file_url(namespace, archive_file_name),
                                       path=root, sha1_hash=archive_hash)

            with zipfile.ZipFile(downloaded_path, 'r') as zf:
                zf.extractall(path=root)
        return path
Example #37
def _download_vocab_tokenizer(root, file_name, file_ext, file_path):
    utils.mkdir(root)

    temp_num = str(random.Random().randint(1, sys.maxsize))
    temp_root = os.path.join(root, temp_num)
    temp_file_path = os.path.join(temp_root, file_name + file_ext)
    temp_zip_file_path = os.path.join(temp_root,
                                      temp_num + '_' + file_name + '.zip')

    repo_url = _get_repo_url()
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=temp_zip_file_path,
             overwrite=True)
    with zipfile.ZipFile(temp_zip_file_path) as zf:
        assert file_name + file_ext in zf.namelist(), \
            '{} not part of {}. Only have: {}'.format(file_name + file_ext,
                                                      file_name + '.zip',
                                                      zf.namelist())
        utils.mkdir(temp_root)
        zf.extractall(temp_root)
        os.replace(temp_file_path, file_path)
        shutil.rmtree(temp_root)
Example #38
def download_img_labels():
    """ Download an image and imagenet1k class labels for test"""
    from mxnet.gluon.utils import download

    img_name = "cat.png"
    synset_url = "".join([
        "https://gist.githubusercontent.com/zhreshold/",
        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
        "imagenet1000_clsid_to_human.txt",
    ])
    synset_name = "synset.txt"
    download(
        "https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true",
        img_name)
    download(synset_url, synset_name)

    with open(synset_name) as fin:
        synset = eval(fin.read())

    with open("synset.csv", "w") as fout:
        w = csv.writer(fout)
        w.writerows(synset.items())
Example #39
File: tf.py Project: LANHUIYING/tvm
def get_workload(model_path, model_sub_path=None):
    """ Import workload from frozen protobuf

    Parameters
    ----------
    model_path: str
        model_path on remote repository to download from.

    model_sub_path: str
        Model path in the compressed archive.

    Returns
    -------
    graph_def: graphdef
        graph_def is the tensorflow workload for mobilenet.

    """

    temp = util.tempdir()
    if model_sub_path:
        path_model = get_workload_official(model_path, model_sub_path, temp)
    else:
        repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
        model_name = os.path.basename(model_path)
        model_url = os.path.join(repo_base, model_path)

        from mxnet.gluon.utils import download
        path_model = temp.relpath(model_name)
        download(model_url, path_model)

    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(path_model, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name='')
        temp.remove()
        return graph_def
Example #40
File: tf.py Project: LANHUIYING/tvm
def get_workload_official(model_url, model_sub_path, temp_dir):
    """ Import workload from tensorflow official

    Parameters
    ----------
    model_url: str
        URL from where it will be downloaded.

    model_sub_path: str
        Sub path in the extracted tar for the frozen protobuf file.

    temp_dir: TempDirectory
        The temporary directory object to download the content.

    Returns
    -------
    graph_def: graphdef
        graph_def is the tensorflow workload for mobilenet.

    """

    model_tar_name = os.path.basename(model_url)

    from mxnet.gluon.utils import download
    temp_path = temp_dir.relpath("./")
    path_model = temp_path + model_tar_name

    download(model_url, path_model)

    import tarfile
    if path_model.endswith("tgz") or path_model.endswith("gz"):
        tar = tarfile.open(path_model)
        tar.extractall(path=temp_path)
        tar.close()
    else:
        raise RuntimeError('Could not decompress the file: ' + path_model)
    return temp_path + model_sub_path
Example #41
File: tf.py Project: zheng-xq/tvm
def get_workload(model_path, model_sub_path=None):
    """ Import workload from frozen protobuf

    Parameters
    ----------
    model_path: str
        model_path on remote repository to download from.

    model_sub_path: str
        Model path in the compressed archive.

    Returns
    -------
    graph_def: graphdef
        graph_def is the tensorflow workload for mobilenet.

    """

    temp = util.tempdir()
    if model_sub_path:
        path_model = get_workload_official(model_path, model_sub_path, temp)
    else:
        repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
        model_name = os.path.basename(model_path)
        model_url = os.path.join(repo_base, model_path)

        from mxnet.gluon.utils import download
        path_model = temp.relpath(model_name)
        download(model_url, path_model)

    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(path_model, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        graph = tf.import_graph_def(graph_def, name='')
        temp.remove()
        return graph_def
Example #42
File: tf.py Project: zheng-xq/tvm
def get_workload_official(model_url, model_sub_path, temp_dir):
    """ Import workload from tensorflow official

    Parameters
    ----------
    model_url: str
        URL from where it will be downloaded.

    model_sub_path: str
        Sub path in the extracted tar for the frozen protobuf file.

    temp_dir: TempDirectory
        The temporary directory object to download the content.

    Returns
    -------
    graph_def: graphdef
        graph_def is the tensorflow workload for mobilenet.

    """

    model_tar_name = os.path.basename(model_url)

    from mxnet.gluon.utils import download
    temp_path = temp_dir.relpath("./")
    path_model = temp_path + model_tar_name

    download(model_url, path_model)

    import tarfile
    if path_model.endswith("tgz") or path_model.endswith("gz"):
        tar = tarfile.open(path_model)
        tar.extractall(path=temp_path)
        tar.close()
    else:
        raise RuntimeError('Could not decompress the file: ' + path_model)
    return temp_path + model_sub_path
Example #43
File: tf.py Project: zhangqiaorjc/tvm-1
def get_workload_inception_v3():
    """ Import Inception V3 workload from frozen protobuf

    Parameters
    ----------
        Nothing.

    Returns
    -------
    (normalized, graph_def) : Tuple
        normalized is normalized input for graph testing.
        graph_def is the tensorflow workload for Inception V3.
    """

    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/'
    model_path = 'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb'

    image_name = 'elephant-299.jpg'
    image_url = os.path.join(repo_base, image_name)
    from mxnet.gluon.utils import download
    download(image_url, image_name)
    normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name))

    return (normalized, get_workload(model_path))
Example #44
    def _get_data(self):
        """Load data from the file. Do nothing if data was loaded before.
        """
        (data_archive_name, archive_hash), (data_name, data_hash) \
            = self._data_file()[self._segment]
        data_path = os.path.join(self._root, data_name)

        if not os.path.exists(data_path) or not check_sha1(data_path, data_hash):
            file_path = download(_get_repo_file_url(self._repo_dir(), data_archive_name),
                                 path=self._root, sha1_hash=archive_hash)

            with zipfile.ZipFile(file_path, 'r') as zf:
                for member in zf.namelist():
                    filename = os.path.basename(member)
                    if filename:
                        dest = os.path.join(self._root, filename)
                        with zf.open(member) as source, open(dest, 'wb') as target:
                            shutil.copyfileobj(source, target)
Example #45
    def _get_data(self):
        """Load data from the file. Does nothing if data was loaded before
        """
        data_archive_name, _, data_hash = self._data_file[self._segment]
        path = os.path.join(self._root, data_archive_name)

        if not os.path.exists(path) or not check_sha1(path, data_hash):
            file_path = download(_get_repo_file_url('gluon/dataset/squad', data_archive_name),
                                 path=self._root, sha1_hash=data_hash)

            with zipfile.ZipFile(file_path, 'r') as zf:
                for member in zf.namelist():
                    filename = os.path.basename(member)

                    if filename:
                        dest = os.path.join(self._root, filename)

                        with zf.open(member) as source, open(dest, 'wb') as target:
                            shutil.copyfileobj(source, target)
Example #46
    def _get_data(self):
        archive_file_name, archive_hash = self._archive_file
        data_file_name, data_hash = self._data_file[self._segment]
        root = self._root
        path = os.path.join(root, data_file_name)
        if not os.path.exists(path) or not check_sha1(path, data_hash):
            downloaded_file_path = download(_get_repo_file_url(self._namespace, archive_file_name),
                                            path=root,
                                            sha1_hash=archive_hash)

            with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
                for member in zf.namelist():
                    filename = os.path.basename(member)
                    if filename:
                        dest = os.path.join(root, filename)
                        with zf.open(member) as source, \
                             open(dest, 'wb') as target:
                            shutil.copyfileobj(source, target)
        return path
Example #47
File: imdb.py Project: tsintian/d2l-en
def load_data_imdb(batch_size, max_len=500):
    """Download an IMDB dataset, return the vocabulary and iterators."""

    data_dir = '../data'
    url = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
    fname = gutils.download(url, data_dir)
    with tarfile.open(fname, 'r') as f:
        f.extractall(data_dir)

    def read_imdb(folder='train'):
        data, labels = [], []
        for label in ['pos', 'neg']:
            folder_name = os.path.join(data_dir, 'aclImdb', folder, label)
            for file in os.listdir(folder_name):
                with open(os.path.join(folder_name, file), 'rb') as f:
                    review = f.read().decode('utf-8').replace('\n', '')
                    data.append(review)
                    labels.append(1 if label == 'pos' else 0)
        return data, labels

    train_data, test_data = read_imdb('train'), read_imdb('test')

    def tokenize(sentences):
        return [line.split(' ') for line in sentences]

    train_tokens = tokenize(train_data[0])
    test_tokens = tokenize(test_data[0])

    vocab = Vocab([tk for line in train_tokens for tk in line], min_freq=5)

    def pad(x):
        return x[:max_len] if len(x) > max_len else x + [vocab.unk] * (max_len - len(x))

    train_features = nd.array([pad(vocab[line]) for line in train_tokens])
    test_features = nd.array([pad(vocab[line]) for line in test_tokens])

    train_set = gdata.ArrayDataset(train_features, train_data[1])
    test_set = gdata.ArrayDataset(test_features, test_data[1])
    train_iter = gdata.DataLoader(train_set, batch_size, shuffle=True)
    test_iter = gdata.DataLoader(test_set, batch_size)

    return vocab, train_iter, test_iter
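A hypothetical call of load_data_imdb above (it assumes the Vocab, gdata, gutils and nd helpers referenced in the snippet are importable):

vocab, train_iter, test_iter = load_data_imdb(batch_size=64)
for X, y in train_iter:
    print(X.shape, y.shape)  # (64, 500) and (64,)
    break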
Example #48
MODEL_NAME = 'rnn'
# Seed value
seed = 'Thus'
# Number of characters to predict
num = 1000

# Download required files
# -----------------------
# Download cfg and weights file if first time.
CFG_NAME = MODEL_NAME + '.cfg'
WEIGHTS_NAME = MODEL_NAME + '.weights'
REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/'
CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true'
WEIGHTS_URL = REPO_URL + 'weights/' + WEIGHTS_NAME + '?raw=true'

download(CFG_URL, CFG_NAME)
download(WEIGHTS_URL, WEIGHTS_NAME)

# Download and Load darknet library
DARKNET_LIB = 'libdarknet.so'
DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true'
download(DARKNET_URL, DARKNET_LIB)
DARKNET_LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
cfg = "./" + str(CFG_NAME)
weights = "./" + str(WEIGHTS_NAME)
net = DARKNET_LIB.load_network(cfg.encode('utf-8'), weights.encode('utf-8'), 0)
dtype = 'float32'
batch_size = 1

# Import the graph to NNVM
# ------------------------
Example #49
#target = 'cuda'
#target_host = 'llvm'
#layout = "NCHW"
#ctx = tvm.gpu(0)
target = 'llvm'
target_host = 'llvm'
layout = None
ctx = tvm.cpu(0)

######################################################################
# Download required files
# -----------------------
# Download files listed above.
from mxnet.gluon.utils import download

download(image_url, img_name)
download(model_url, model_name)
download(map_proto_url, map_proto)
download(lable_map_url, lable_map)

######################################################################
# Import model
# ------------
# Creates tensorflow graph definition from protobuf file.

with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    graph = tf.import_graph_def(graph_def, name='')
    # Call the utility to import the graph definition into default graph.
    graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
Example #50
# ResNet model from model zoo, which is pretrained on ImageNet. You
# can find more details about this part at `Compile MXNet Models`

from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download
from PIL import Image
import numpy as np

# only one line to get the model
block = get_model('resnet18_v1', pretrained=True)

######################################################################
# In order to test our model, here we download an image of a cat and
# transform its format.
img_name = 'cat.jpg'
download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
image = Image.open(img_name).resize((224, 224))

def transform_image(image):
    image = np.array(image) - np.array([123., 117., 104.])
    image /= np.array([58.395, 57.12, 57.375])
    image = image.transpose((2, 0, 1))
    image = image[np.newaxis, :]
    return image

x = transform_image(image)


######################################################################
# synset is used to transform the label from an ImageNet class number to
# a word humans can understand.