Example #1
def import_onnx():
    """Import the onnx model into mxnet"""
    model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx'
    download(model_url, 'super_resolution.onnx')

    LOGGER.info("Converting onnx format to mxnet's symbol and params...")
    sym, arg_params, aux_params = onnx_mxnet.import_model('super_resolution.onnx')
    LOGGER.info("Successfully Converted onnx format to mxnet's symbol and params...")
    return sym, arg_params, aux_params
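The returned symbol and parameter dictionaries can be fed to MXNet's Module API for inference. A minimal sketch; the input name 'input_0' and the (1, 1, 224, 224) shape are assumptions about this particular ONNX graph, not guarantees:

from collections import namedtuple
import mxnet as mx

sym, arg_params, aux_params = import_onnx()
mod = mx.mod.Module(symbol=sym, data_names=['input_0'], label_names=None)
mod.bind(for_training=False, data_shapes=[('input_0', (1, 1, 224, 224))])
mod.set_params(arg_params=arg_params, aux_params=aux_params)

Batch = namedtuple('Batch', ['data'])
mod.forward(Batch([mx.nd.ones((1, 1, 224, 224))]))  # dummy luminance input
output = mod.get_outputs()[0].asnumpy()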
Example #2
def get_test_image():
    """Download and process the test image"""
    # Load test image
    input_image_dim = 224
    img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'
    download(img_url, 'super_res_input.jpg')
    img = Image.open('super_res_input.jpg').resize((input_image_dim, input_image_dim))
    img_ycbcr = img.convert("YCbCr")
    img_y, img_cb, img_cr = img_ycbcr.split()
    input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
    return input_image, img_cb, img_cr
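Only the Y (luminance) channel is returned as network input because the super-resolution model operates on luminance; Cb and Cr are kept so the color image can be reassembled afterwards. A sketch of the usual post-processing, assuming output is the model's (1, 1, H, W) result array:

img_out_y = Image.fromarray(np.uint8(output[0, 0].clip(0, 255)), mode='L')
result = Image.merge("YCbCr", [
    img_out_y,
    img_cb.resize(img_out_y.size, Image.BICUBIC),
    img_cr.resize(img_out_y.size, Image.BICUBIC),
]).convert("RGB")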
Example #3
def download_sick(dirpath):
    if os.path.exists(dirpath):
        print('Found SICK dataset - skip')
        return
    else:
        os.makedirs(dirpath)
    train_url = 'http://alt.qcri.org/semeval2014/task1/data/uploads/sick_train.zip'
    trial_url = 'http://alt.qcri.org/semeval2014/task1/data/uploads/sick_trial.zip'
    test_url = 'http://alt.qcri.org/semeval2014/task1/data/uploads/sick_test_annotated.zip'
    unzip(download(train_url, dirname=dirpath))
    unzip(download(trial_url, dirname=dirpath))
    unzip(download(test_url, dirname=dirpath))
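The unzip helper these snippets call is not defined here; a minimal sketch consistent with how it is used (it receives the file path returned by download):

import os
import zipfile

def unzip(filepath):
    with zipfile.ZipFile(filepath) as zf:
        zf.extractall(os.path.dirname(filepath))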
Example #4
def download_sick(dirpath):
    if os.path.exists(dirpath):
        print('Found SICK dataset - skip')
        return
    else:
        os.makedirs(dirpath)
    train_url = 'http://alt.qcri.org/semeval2014/task1/data/uploads/sick_train.zip'
    trial_url = 'http://alt.qcri.org/semeval2014/task1/data/uploads/sick_trial.zip'
    test_url = 'http://alt.qcri.org/semeval2014/task1/data/uploads/sick_test_annotated.zip'
    unzip(download(train_url, dirname=dirpath))
    unzip(download(trial_url, dirname=dirpath))
    unzip(download(test_url, dirname=dirpath))
Example #5
def get_test_image():
    """Download and process the test image"""
    # Load test image
    input_image_dim = 224
    img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'
    download(img_url, 'super_res_input.jpg')
    img = Image.open('super_res_input.jpg').resize(
        (input_image_dim, input_image_dim))
    img_ycbcr = img.convert("YCbCr")
    img_y, img_cb, img_cr = img_ycbcr.split()
    input_image = np.array(img_y)[np.newaxis, np.newaxis, :, :]
    return input_image, img_cb, img_cr
Example #6
def get_Dataset():
    url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/{}'
    hashes = {
        'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8',
        'train.idx': 'dcf7318b2602c06428b9988470c731621716c393',
        'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'
    }
    for k, v in hashes.items():
        fname = k
        target = osp.join('data', fname)
        url = url_format.format(k)
        if not osp.exists(target) or not verified(target, v):
            print('Downloading', target, url)
            download(url, fname=fname, dirname='data', overwrite=True)
Example #7
def get_test_files(name):
    """Extract the tar file and return the model path and input/output data."""
    tar_name = download(URLS.get(name), dirname=CURR_PATH.__str__())
    # extract tar file
    tar_path = os.path.join(CURR_PATH, tar_name)
    tar = tarfile.open(tar_path.__str__(), "r:*")
    tar.extractall(path=CURR_PATH.__str__())
    tar.close()
    data_dir = os.path.join(CURR_PATH, name)
    model_path = os.path.join(data_dir, 'model.onnx')

    inputs = []
    outputs = []
    # get test files
    for test_file in os.listdir(data_dir):
        case_dir = os.path.join(data_dir, test_file)
        # skip the non-dir files
        if not os.path.isdir(case_dir):
            continue
        input_file = os.path.join(case_dir, 'input_0.pb')
        input_tensor = TensorProto()
        with open(input_file, 'rb') as proto_file:
            input_tensor.ParseFromString(proto_file.read())
        inputs.append(numpy_helper.to_array(input_tensor))

        output_tensor = TensorProto()
        output_file = os.path.join(case_dir, 'output_0.pb')
        with open(output_file, 'rb') as proto_file:
            output_tensor.ParseFromString(proto_file.read())
        outputs.append(numpy_helper.to_array(output_tensor))

    return model_path, inputs, outputs
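A hedged usage sketch: import the extracted ONNX model and compare its predictions with the bundled reference outputs. The 'bvlc_alexnet' key is hypothetical, and forward_pass is assumed to run one batch through the model (see the sketch after Example #16):

import numpy as np

model_path, inputs, outputs = get_test_files('bvlc_alexnet')  # hypothetical key
sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
for data, expected in zip(inputs, outputs):
    result = forward_pass(sym, arg_params, aux_params, ['input_0'], mx.nd.array(data))
    np.testing.assert_almost_equal(expected, result, decimal=3)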
Example #8
def get_mnist_iterator(rank):
    data_dir = "data-%d" % rank
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)
    zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
                             dirname=data_dir)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(data_dir)

    input_shape = (1, 28, 28)
    batch_size = args.batch_size

    train_iter = mx.io.MNISTIter(
        image="%s/train-images-idx3-ubyte" % data_dir,
        label="%s/train-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=False,
        num_parts=hvd.size(),
        part_index=hvd.rank()
    )

    val_iter = mx.io.MNISTIter(
        image="%s/t10k-images-idx3-ubyte" % data_dir,
        label="%s/t10k-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        flat=False,
        num_parts=hvd.size(),
        part_index=hvd.rank()
    )

    return train_iter, val_iter
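num_parts and part_index shard the data so each Horovod worker reads a distinct slice. A minimal driver sketch, assuming Horovod's MXNet bindings and a module-level args object providing batch_size:

import horovod.mxnet as hvd

hvd.init()
train_iter, val_iter = get_mnist_iterator(hvd.rank())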
Example #9
def get_mnist_iterator(rank):
    data_dir = "data-%d" % rank
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)
    zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
                             dirname=data_dir)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(data_dir)

    input_shape = (1, 28, 28)
    batch_size = args.batch_size

    train_iter = mx.io.MNISTIter(
        image="%s/train-images-idx3-ubyte" % data_dir,
        label="%s/train-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=False,
        num_parts=hvd.size(),
        part_index=hvd.rank()
    )

    # Note: unlike Example #8, the validation iterator here is not sharded
    # across Horovod workers, so every worker evaluates the full test set.
    val_iter = mx.io.MNISTIter(
        image="%s/t10k-images-idx3-ubyte" % data_dir,
        label="%s/t10k-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        flat=False,
    )

    return train_iter, val_iter
Example #10
def get_test_files(name):
    """Extract tar file and returns model path and input, output data"""
    tar_name = download(URLS.get(name), dirname=CURR_PATH.__str__())
    # extract tar file
    tar_path = os.path.join(CURR_PATH, tar_name)
    tar = tarfile.open(tar_path.__str__(), "r:*")
    tar.extractall(path=CURR_PATH.__str__())
    tar.close()
    data_dir = os.path.join(CURR_PATH, name)
    model_path = os.path.join(data_dir, 'model.onnx')

    inputs = []
    outputs = []
    # get test files
    for test_file in os.listdir(data_dir):
        case_dir = os.path.join(data_dir, test_file)
        # skip the non-dir files
        if not os.path.isdir(case_dir):
            continue
        input_file = os.path.join(case_dir, 'input_0.pb')
        input_tensor = TensorProto()
        with open(input_file, 'rb') as proto_file:
            input_tensor.ParseFromString(proto_file.read())
        inputs.append(numpy_helper.to_array(input_tensor))

        output_tensor = TensorProto()
        output_file = os.path.join(case_dir, 'output_0.pb')
        with open(output_file, 'rb') as proto_file:
            output_tensor.ParseFromString(proto_file.read())
        outputs.append(numpy_helper.to_array(output_tensor))

    return model_path, inputs, outputs
Example #11
def get_dataset(prefetch=False):
    image_path = os.path.join(dataset_path, "BSDS300/images")

    if not os.path.exists(image_path):
        os.makedirs(dataset_path)
        file_name = download(dataset_url)
        with tarfile.open(file_name) as tar:
            for item in tar:
                tar.extract(item, dataset_path)
        os.remove(file_name)

    crop_size = 256
    crop_size -= crop_size % upscale_factor
    input_crop_size = crop_size // upscale_factor

    input_transform = [
        CenterCropAug((crop_size, crop_size)),
        ResizeAug(input_crop_size)
    ]
    target_transform = [CenterCropAug((crop_size, crop_size))]

    iters = (ImagePairIter(os.path.join(image_path, "train"),
                           (input_crop_size, input_crop_size),
                           (crop_size, crop_size), batch_size, color_flag,
                           input_transform, target_transform),
             ImagePairIter(os.path.join(image_path, "test"),
                           (input_crop_size, input_crop_size),
                           (crop_size, crop_size), test_batch_size, color_flag,
                           input_transform, target_transform))

    return [PrefetchingIter(i) for i in iters] if prefetch else iters
Example #12
def get_dataset(prefetch=False):
    image_path = os.path.join(dataset_path, "BSDS300/images")

    if not os.path.exists(image_path):
        os.makedirs(dataset_path)
        file_name = download(dataset_url)
        with tarfile.open(file_name) as tar:
            for item in tar:
                tar.extract(item, dataset_path)
        os.remove(file_name)

    crop_size = 256
    crop_size -= crop_size % upscale_factor
    input_crop_size = crop_size // upscale_factor

    input_transform = [CenterCropAug((crop_size, crop_size)), ResizeAug(input_crop_size)]
    target_transform = [CenterCropAug((crop_size, crop_size))]

    iters = (ImagePairIter(os.path.join(image_path, "train"),
                           (input_crop_size, input_crop_size),
                           (crop_size, crop_size),
                           batch_size, color_flag, input_transform, target_transform),
             ImagePairIter(os.path.join(image_path, "test"),
                           (input_crop_size, input_crop_size),
                           (crop_size, crop_size),
                           test_batch_size, color_flag,
                           input_transform, target_transform))

    return [PrefetchingIter(i) for i in iters] if prefetch else iters
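Both get_dataset variants above rely on module-level configuration that is not part of the snippets. A sketch of the names they assume; the values are illustrative, not taken from the source:

dataset_path = "data"
dataset_url = "http://example.com/BSDS300-images.tgz"  # placeholder URL
upscale_factor = 3
batch_size = 64
test_batch_size = 16
color_flag = 0  # 0 = grayscale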
Example #13
def download_wordvecs(dirpath):
    if os.path.exists(dirpath):
        print('Found Glove vectors - skip')
        return
    else:
        os.makedirs(dirpath)
    url = 'http://www-nlp.stanford.edu/data/glove.840B.300d.zip'
    unzip(download(url, dirname=dirpath))
Example #14
def download_wordvecs(dirpath):
    if os.path.exists(dirpath):
        print('Found Glove vectors - skip')
        return
    else:
        os.makedirs(dirpath)
    url = 'http://www-nlp.stanford.edu/data/glove.840B.300d.zip'
    unzip(download(url, dirname=dirpath))
Example #15
def download_model(model_name, model_path):
    # reference: https://github.com/mlperf/inference/tree/master/v0.5/classification_and_detection
    if model_name == 'resnet50-v1.5':
        model_url = 'https://zenodo.org/record/2592612/files/resnet50_v1.onnx'
        data_shape = (1, 3, 224, 224)
    else:
        raise ValueError('Model: {} not implemented.'.format(model_name))

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    onnx_model_file = os.path.join(model_path, model_name + '.onnx')
    print(onnx_model_file)

    if not os.path.exists(onnx_model_file):
        print("Downloading ONNX model from: {}".format(model_url))
        download(model_url, onnx_model_file)

    return onnx_model_file, data_shape
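A usage sketch: fetch the ONNX file, then import it into MXNet. onnx_mxnet.import_model is assumed to be available, as in the other examples:

onnx_file, data_shape = download_model('resnet50-v1.5', './models')
sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_file)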
Example #16
def test_nodims_import():
    # Download test model without dims mentioned in params
    test_model = download(test_model_path, dirname=CURR_PATH.__str__())
    input_data = np.array([0.2, 0.5])
    nd_data = mx.nd.array(input_data).expand_dims(0)
    sym, arg_params, aux_params = onnx_mxnet.import_model(test_model)
    model_metadata = onnx_mxnet.get_model_metadata(test_model)
    input_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]
    output_data = forward_pass(sym, arg_params, aux_params, input_names, nd_data)
    assert output_data.shape == (1, 1)
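forward_pass is not defined in this snippet. A plausible sketch built on the Module API, consistent with how it is called above; treat it as an assumption, not the test suite's actual helper:

from collections import namedtuple

def forward_pass(sym, arg_params, aux_params, input_names, nd_data):
    """Run one batch through the model and return the first output."""
    mod = mx.mod.Module(symbol=sym, data_names=input_names, label_names=None)
    mod.bind(for_training=False, data_shapes=[(input_names[0], nd_data.shape)])
    mod.set_params(arg_params=arg_params, aux_params=aux_params)
    Batch = namedtuple('Batch', ['data'])
    mod.forward(Batch([nd_data]))
    return mod.get_outputs()[0].asnumpy()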
Example #17
def GetMNIST_ubyte():
    if not os.path.isdir("data"):
        os.makedirs('data')
    if (not os.path.exists('data/train-images-idx3-ubyte')) or \
       (not os.path.exists('data/train-labels-idx1-ubyte')) or \
       (not os.path.exists('data/t10k-images-idx3-ubyte')) or \
       (not os.path.exists('data/t10k-labels-idx1-ubyte')):
        zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
                                 dirname='data')
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall('data')
Example #18
def GetCifar10():
    if not os.path.isdir("data"):
        os.makedirs('data')
    if (not os.path.exists('data/cifar/train.rec')) or \
       (not os.path.exists('data/cifar/test.rec')) or \
       (not os.path.exists('data/cifar/train.lst')) or \
       (not os.path.exists('data/cifar/test.lst')):
        zip_file_path = download('http://data.mxnet.io/mxnet/data/cifar10.zip',
                                 dirname='data')
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall('data')
Example #19
def GetCifar10():
    if not os.path.isdir("data"):
        os.makedirs('data')
    if (not os.path.exists('data/cifar/train.rec')) or \
       (not os.path.exists('data/cifar/test.rec')) or \
       (not os.path.exists('data/cifar/train.lst')) or \
       (not os.path.exists('data/cifar/test.lst')):
        zip_file_path = download('http://data.mxnet.io/mxnet/data/cifar10.zip',
                                 dirname='data')
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall('data')
Example #20
def GetMNIST_ubyte():
    if not os.path.isdir("data"):
        os.makedirs('data')
    if (not os.path.exists('data/train-images-idx3-ubyte')) or \
       (not os.path.exists('data/train-labels-idx1-ubyte')) or \
       (not os.path.exists('data/t10k-images-idx3-ubyte')) or \
       (not os.path.exists('data/t10k-labels-idx1-ubyte')):
        zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
                                 dirname='data')
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall('data')
Example #21
def download_parser(dirpath):
    parser_dir = 'stanford-parser'
    if os.path.exists(os.path.join(dirpath, parser_dir)):
        print('Found Stanford Parser - skip')
        return
    url = 'http://nlp.stanford.edu/software/stanford-parser-full-2015-01-29.zip'
    filepath = download(url, dirname=dirpath)
    zip_dir = ''
    with zipfile.ZipFile(filepath) as zf:
        zip_dir = zf.namelist()[0]
        zf.extractall(dirpath)
    os.remove(filepath)
    os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, parser_dir))
Example #22
def download_training_data():
    print('downloading training data...')
    if not os.path.isdir("data"):
        os.makedirs('data')
    if (not os.path.exists('data/train.rec')) or \
            (not os.path.exists('data/test.rec')) or \
            (not os.path.exists('data/train.lst')) or \
            (not os.path.exists('data/test.lst')):
        zip_file_path = download('http://data.mxnet.io/mxnet/data/cifar10.zip')
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall()
        os.rename('cifar', 'data')
    print('done')
Example #23
def download_training_data():
    print('downloading training data...')
    if not os.path.isdir("data"):
        os.makedirs('data')
    if (not os.path.exists('data/train.rec')) or \
            (not os.path.exists('data/test.rec')) or \
            (not os.path.exists('data/train.lst')) or \
            (not os.path.exists('data/test.lst')):
        zip_file_path = download('http://data.mxnet.io/mxnet/data/cifar10.zip')
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall()
        os.rename('cifar', 'data')
    print('done')
Example #24
def download_training_data(fileName):
    print('downloading training data...')
    if not os.path.isdir("data"):
        os.makedirs('data')
    if (not os.path.exists('data/train.rec')) or \
            (not os.path.exists('data/test.rec')) or \
            (not os.path.exists('data/train.lst')) or \
            (not os.path.exists('data/test.lst')):
        zip_file_path = download('https://sagemaker-crops-corn.s3.amazonaws.com/' + str(fileName)) #'http://data.mxnet.io/mxnet/data/cifar10.zip')
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall()
        os.rename('cifar', 'data')
    print('done')
Example #25
def get_cifar10(dir="data"):
    """Downloads CIFAR10 dataset into a directory in the current directory with the name `data`,
    and then extracts all files into the directory `data/cifar`.
    """
    if not os.path.isdir(dir):
        os.makedirs(dir)
    if (not os.path.exists(os.path.join(dir, 'cifar', 'train.rec'))) or \
            (not os.path.exists(os.path.join(dir, 'cifar', 'test.rec'))) or \
            (not os.path.exists(os.path.join(dir, 'cifar', 'train.lst'))) or \
            (not os.path.exists(os.path.join(dir, 'cifar', 'test.lst'))):
        zip_file_path = download('http://data.mxnet.io/mxnet/data/cifar10.zip',
                                 dirname=dir)
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall(dir)
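Once extracted, the .rec files are typically consumed with mx.io.ImageRecordIter. A minimal sketch; the 28x28 record size matches how this archive is commonly used in MXNet's tests, but treat the shapes and batch size as assumptions:

train_iter = mx.io.ImageRecordIter(
    path_imgrec="data/cifar/train.rec",
    data_shape=(3, 28, 28),
    batch_size=128,
    shuffle=True,
)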
Example #26
def download_parser(dirpath):
    parser_dir = 'stanford-parser'
    if os.path.exists(os.path.join(dirpath, parser_dir)):
        print('Found Stanford Parser - skip')
        return
    url = 'http://nlp.stanford.edu/software/stanford-parser-full-2015-01-29.zip'
    filepath = download(url, dirname=dirpath)
    zip_dir = ''
    with zipfile.ZipFile(filepath) as zf:
        zip_dir = zf.namelist()[0]
        zf.extractall(dirpath)
    os.remove(filepath)
    os.rename(os.path.join(dirpath, zip_dir),
              os.path.join(dirpath, parser_dir))
Example #27
def download_training_data():
    print("downloading training data...")
    if not os.path.isdir("data"):
        os.makedirs("data")
    if (
        (not os.path.exists("data/train.rec"))
        or (not os.path.exists("data/test.rec"))
        or (not os.path.exists("data/train.lst"))
        or (not os.path.exists("data/test.lst"))
    ):
        zip_file_path = download("http://data.mxnet.io/mxnet/data/cifar10.zip")
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall()
        os.rename("cifar", "data")
    print("done")
Example #28
def download_voc(path, overwrite=False):
    _DOWNLOAD_URLS = [
        ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
         '34ed68851bce2a36e2a223fa52c661d592c66b3c'),
        ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
         '41a8d6e12baa5ab18ee7f8f8029b9e11805b4ef1'),
        ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
         '4e443f8a2eca6b1dac8a6c57641b67dd40621a49')
    ]
    makedirs(path)
    for url, checksum in _DOWNLOAD_URLS:
        #         try:
        #             filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
        #         except:
        filename = download(url, dirname=path)
        # extract
        with tarfile.open(filename) as tar:
            tar.extractall(path=path)
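The commented-out lines hint at checksum verification. mxnet.gluon.utils.download supports this directly via its sha1_hash argument; a hedged alternative for the loop body:

from mxnet.gluon.utils import download as gluon_download

for url, checksum in _DOWNLOAD_URLS:
    filename = gluon_download(url, path=path, overwrite=overwrite, sha1_hash=checksum)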
Example #29
def verified(file_path, sha1hash):
    import hashlib
    sha1 = hashlib.sha1()
    with open(file_path, 'rb') as f:
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)
    matched = sha1.hexdigest() == sha1hash
    if not matched:
        print('Found hash mismatch in file {}, possibly due to incomplete download.'.format(file_path))
    return matched

url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/dataset/pikachu/{}'
hashes = {'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8',
          'train.idx': 'dcf7318b2602c06428b9988470c731621716c393',
          'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'}
for k, v in hashes.items():
    fname = 'pikachu_' + k
    target = osp.join('data', fname)
    url = url_format.format(k)
    if not osp.exists(target) or not verified(target, v):
        print('Downloading', target, url)
        download(url, fname=fname, dirname='data', overwrite=True)


import mxnet.image as image
data_shape = 256
batch_size = 32
def get_iterators(data_shape, batch_size):
    class_names = ['pikachu']
    num_class = len(class_names)
    train_iter = image.ImageDetIter(
        batch_size=batch_size,
        data_shape=(3, data_shape, data_shape),
        path_imgrec='./data/pikachu_train.rec',
        path_imgidx='./data/pikachu_train.idx',
        shuffle=True,
        mean=True,
Example #30
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import zipfile
import shutil
from mxnet.test_utils import download

zip_file_path = 'models/msgnet_21styles.zip'
download('https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/models/msgnet_21styles-2cb88353.zip', zip_file_path)

with zipfile.ZipFile(zip_file_path) as zf:
    zf.extractall()

os.remove(zip_file_path)

shutil.move('msgnet_21styles-2cb88353.params', 'models/21styles.params')
Example #31
def get_dataset(prefetch=False):
    """Download the BSDS500 dataset and return train and test iters."""

    if path.exists(data_dir):
        print(
            "Directory {} already exists, skipping.\n"
            "To force download and extraction, delete the directory and re-run."
            "".format(data_dir),
            file=sys.stderr,
        )
    else:
        print("Downloading dataset...", file=sys.stderr)
        downloaded_file = download(dataset_url, dirname=datasets_tmpdir)
        print("done", file=sys.stderr)

        print("Extracting files...", end="", file=sys.stderr)
        os.makedirs(data_dir)
        os.makedirs(tmp_dir)
        with zipfile.ZipFile(downloaded_file) as archive:
            archive.extractall(tmp_dir)
        shutil.rmtree(datasets_tmpdir)

        shutil.copytree(
            path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "images"),
            path.join(data_dir, "images"),
        )
        shutil.copytree(
            path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "groundTruth"),
            path.join(data_dir, "groundTruth"),
        )
        shutil.rmtree(tmp_dir)
        print("done", file=sys.stderr)

    crop_size = 256
    crop_size -= crop_size % upscale_factor
    input_crop_size = crop_size // upscale_factor

    input_transform = [CenterCropAug((crop_size, crop_size)), ResizeAug(input_crop_size)]
    target_transform = [CenterCropAug((crop_size, crop_size))]

    iters = (
        ImagePairIter(
            path.join(data_dir, "images", "train"),
            (input_crop_size, input_crop_size),
            (crop_size, crop_size),
            batch_size,
            color_flag,
            input_transform,
            target_transform,
        ),
        ImagePairIter(
            path.join(data_dir, "images", "test"),
            (input_crop_size, input_crop_size),
            (crop_size, crop_size),
            test_batch_size,
            color_flag,
            input_transform,
            target_transform,
        ),
    )

    return [PrefetchingIter(i) for i in iters] if prefetch else iters
Example #32
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Testing super_resolution model conversion"""
from __future__ import absolute_import as _abs
from __future__ import print_function
from collections import namedtuple
import mxnet as mx
from mxnet.test_utils import download
import numpy as np
from PIL import Image
import onnx_mxnet

model_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_resolution.onnx'

download(model_url, 'super_resolution.onnx')

print("Converting onnx format to mxnet's symbol and params...")
sym, params = onnx_mxnet.import_model('super_resolution.onnx')

# Load test image
input_image_dim = 224
output_image_dim = 672
img_url = 'https://s3.amazonaws.com/onnx-mxnet/examples/super_res_input.jpg'
download(img_url, 'super_res_input.jpg')
img = Image.open('super_res_input.jpg').resize(
    (input_image_dim, input_image_dim))
img_ycbcr = img.convert("YCbCr")
img_y, img_cb, img_cr = img_ycbcr.split()
x = np.array(img_y)[np.newaxis, np.newaxis, :, :]
Example #33
def get_dataset(prefetch=False):
    """Download the BSDS500 dataset and return train and test iters."""

    if path.exists(data_dir):
        print(
            "Directory {} already exists, skipping.\n"
            "To force download and extraction, delete the directory and re-run."
            "".format(data_dir),
            file=sys.stderr,
        )
    else:
        print("Downloading dataset...", file=sys.stderr)
        downloaded_file = download(dataset_url, dirname=datasets_tmpdir)
        print("done", file=sys.stderr)

        print("Extracting files...", end="", file=sys.stderr)
        os.makedirs(data_dir)
        os.makedirs(tmp_dir)
        with zipfile.ZipFile(downloaded_file) as archive:
            archive.extractall(tmp_dir)
        shutil.rmtree(datasets_tmpdir)

        shutil.copytree(
            path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "images"),
            path.join(data_dir, "images"),
        )
        shutil.copytree(
            path.join(tmp_dir, "BSDS500-master", "BSDS500", "data",
                      "groundTruth"),
            path.join(data_dir, "groundTruth"),
        )
        shutil.rmtree(tmp_dir)
        print("done", file=sys.stderr)

    crop_size = 256
    crop_size -= crop_size % upscale_factor
    input_crop_size = crop_size // upscale_factor

    input_transform = [
        CenterCropAug((crop_size, crop_size)),
        ResizeAug(input_crop_size)
    ]
    target_transform = [CenterCropAug((crop_size, crop_size))]

    iters = (
        ImagePairIter(
            path.join(data_dir, "images", "train"),
            (input_crop_size, input_crop_size),
            (crop_size, crop_size),
            batch_size,
            color_flag,
            input_transform,
            target_transform,
        ),
        ImagePairIter(
            path.join(data_dir, "images", "test"),
            (input_crop_size, input_crop_size),
            (crop_size, crop_size),
            test_batch_size,
            color_flag,
            input_transform,
            target_transform,
        ),
    )

    return [PrefetchingIter(i) for i in iters] if prefetch else iters
Example #34
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import zipfile
import shutil
from mxnet.test_utils import download

zip_file_path = 'models/msgnet_21styles.zip'
download(
    'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/models/msgnet_21styles-2cb88353.zip',
    zip_file_path)

with zipfile.ZipFile(zip_file_path) as zf:
    zf.extractall()

os.remove(zip_file_path)

shutil.move('msgnet_21styles-2cb88353.params', 'models/21styles.params')
Example #35
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os, zipfile
import mxnet
from mxnet.test_utils import download

def unzip_file(filename, outpath):
    with zipfile.ZipFile(filename) as z:
        z.extractall(outpath)

# Dataset from COCO 2014: http://cocodataset.org/#download
# The dataset annotations and site are Copyright COCO Consortium and licensed CC BY 4.0 Attribution.
# The images within the dataset are available under the Flickr Terms of Use.
# See http://cocodataset.org/#termsofuse for details
download('http://msvocds.blob.core.windows.net/coco2014/train2014.zip', 'dataset/train2014.zip')
download('http://msvocds.blob.core.windows.net/coco2014/val2014.zip', 'dataset/val2014.zip')

unzip_file('dataset/train2014.zip', 'dataset')
unzip_file('dataset/val2014.zip', 'dataset')
Example #36
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from mxnet.test_utils import download
import os.path as osp
def verified(file_path, sha1hash):
    import hashlib
    sha1 = hashlib.sha1()
    with open(file_path, 'rb') as f:
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)
    matched = sha1.hexdigest() == sha1hash
    if not matched:
        print('Found hash mismatch in file {}, possibly due to incomplete download.'.format(file_path))
    return matched

url_format = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/datasets/pikachu/{}'
hashes = {'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8',
          'train.idx': 'dcf7318b2602c06428b9988470c731621716c393',
          'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'}
for k, v in hashes.items():  # key: filename, value: sha1 checksum
    fname = 'pikachu_' + k
    target = osp.join('data', fname)  # destination inside the data directory
    url = url_format.format(k)
    if not osp.exists(target) or not verified(target, v):
        print('Downloading', target, url)
        download(url, fname=fname, dirname='data', overwrite=True)
Example #37
def GetMNIST_pkl():
    if not os.path.isdir("data"):
        os.makedirs('data')
    if not os.path.exists('data/mnist.pkl.gz'):
        download('http://deeplearning.net/data/mnist/mnist.pkl.gz',
                 dirname='data')
Example #38
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from mxnet.test_utils import download

download(
    'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/models/21styles-32f7205c.params',
    'models/21styles.params')
Example #39
def GetMNIST_pkl():
    if not os.path.isdir("data"):
        os.makedirs('data')
    if not os.path.exists('data/mnist.pkl.gz'):
        download('http://deeplearning.net/data/mnist/mnist.pkl.gz',
                 dirname='data')
Example #40
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os, zipfile
import mxnet
from mxnet.test_utils import download

def unzip_file(filename, outpath):
    with zipfile.ZipFile(filename) as z:
        z.extractall(outpath)

download('http://msvocds.blob.core.windows.net/coco2014/train2014.zip', 'dataset/train2014.zip')
download('http://msvocds.blob.core.windows.net/coco2014/val2014.zip', 'dataset/val2014.zip')

unzip_file('dataset/train2014.zip', 'dataset')
unzip_file('dataset/val2014.zip', 'dataset')
Example #41
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from mxnet.test_utils import download

download('https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/models/21styles-32f7205c.params', 'models/21styles.params')
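A short usage sketch for the downloaded parameter file: load it into NDArrays and inspect the names. How the parameters are consumed depends on the style-transfer network, which is not shown here:

import mxnet as mx

params = mx.nd.load('models/21styles.params')
print(list(params.keys())[:5])  # peek at a few parameter names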