def download(self):
        """
        Build the processed MNIST1M torch files from the raw idx files.

        Returns immediately when ``self._check_exists()`` is truthy. Otherwise
        the training pair (``mnist1m-*``) and the test pair (``t10k-*``) are
        read from ``self.raw_folder`` and each ``(images, labels)`` tuple is
        written with ``torch.save`` into ``self.processed_folder`` under
        ``self.training_file`` and ``self.test_file``.
        """

        # nothing to do when the processed files are already in place:
        if self._check_exists():
            return

        logging.info("Processing MNIST1M data...")
        os.makedirs(self.processed_folder, exist_ok=True)

        def raw_path(name):
            return os.path.join(self.raw_folder, name)

        # read all four raw files before anything is written:
        train_images = read_image_file(raw_path("mnist1m-images-idx3-ubyte"))
        train_labels = read_label_file(raw_path("mnist1m-labels-idx1-ubyte"))
        eval_images = read_image_file(raw_path("t10k-images-idx3-ubyte"))
        eval_labels = read_label_file(raw_path("t10k-labels-idx1-ubyte"))

        # save each (images, labels) tuple as a torch file:
        train_target = os.path.join(self.processed_folder, self.training_file)
        with open(train_target, "wb") as handle:
            torch.save((train_images, train_labels), handle)

        eval_target = os.path.join(self.processed_folder, self.test_file)
        with open(eval_target, "wb") as handle:
            torch.save((eval_images, eval_labels), handle)

        logging.info("Done!")
Пример #2
0
    def load(self, folder):
        """Unpack the raw archives from ``folder`` and build the processed files.

        Returns immediately when ``self._check_exists()`` is truthy. Otherwise
        ``self.raw_folder`` and ``self.processed_folder`` are created, every
        ``(url, md5)`` pair in ``self.resources`` is handed to
        ``load_and_extract_archive`` (file name = last path component of the
        url, ``source_root=folder``, ``target_root=self.raw_folder``), and the
        train / t10k idx files are read and saved with ``torch.save`` as
        ``self.training_file`` and ``self.test_file``.

        Args:
            folder: passed to ``load_and_extract_archive`` as ``source_root``.
        """
        if self._check_exists():
            return

        for directory in (self.raw_folder, self.processed_folder):
            os.makedirs(directory, exist_ok=True)

        for resource_url, checksum in self.resources:
            archive_name = resource_url.rpartition('/')[2]
            load_and_extract_archive(
                source_root=folder,
                target_root=self.raw_folder,
                filename=archive_name,
                md5=checksum,
            )  # NOTICE

        # process and save as torch files
        print('Processing...')

        def raw_path(name):
            return os.path.join(self.raw_folder, name)

        train_images = read_image_file(raw_path('train-images-idx3-ubyte'))
        train_labels = read_label_file(raw_path('train-labels-idx1-ubyte'))
        eval_images = read_image_file(raw_path('t10k-images-idx3-ubyte'))
        eval_labels = read_label_file(raw_path('t10k-labels-idx1-ubyte'))

        train_target = os.path.join(self.processed_folder, self.training_file)
        with open(train_target, 'wb') as handle:
            torch.save((train_images, train_labels), handle)

        eval_target = os.path.join(self.processed_folder, self.test_file)
        with open(eval_target, 'wb') as handle:
            torch.save((eval_images, eval_labels), handle)

        print('Done!')
Пример #3
0
def data_loader(root_path, train=True):
    """Read MNIST idx files below ``root_path`` into a dictionary.

    Args:
        root_path: directory holding the ``train`` and ``test`` subfolders.
        train: when truthy, both splits are loaded with their labels; when
            falsy, only the test images are loaded.

    Returns:
        With ``train`` truthy: ``{'train': {'data', 'label'}, 'test':
        {'data', 'label'}}``. Otherwise a flat ``{'data', 'label'}`` dict
        whose ``'label'`` is ``None``.
    """
    test_dir = os.path.join(root_path, 'test')

    # inference-style call: test images only, no labels
    if not train:
        return {
            'data': read_image_file(
                os.path.join(test_dir, 't10k-images-idx3-ubyte')),
            'label': None,
        }

    train_dir = os.path.join(root_path, 'train')
    train_images = read_image_file(
        os.path.join(train_dir, 'train-images-idx3-ubyte'))
    train_labels = read_label_file(
        os.path.join(train_dir, 'train-labels-idx1-ubyte'))
    test_images = read_image_file(
        os.path.join(test_dir, 't10k-images-idx3-ubyte'))
    test_labels = read_label_file(
        os.path.join(test_dir, 't10k-labels-idx1-ubyte'))

    return {
        'train': {'data': train_images, 'label': train_labels},
        'test': {'data': test_images, 'label': test_labels},
    }
    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already.

        Returns immediately when ``self._check_exists()`` is truthy. Otherwise:

        1. creates ``<root>/<raw_folder>`` and ``<root>/<processed_folder>``;
        2. fetches every url in ``self.urls`` with ``download_url`` (no md5
           check), gunzips ``*.gz`` downloads next to the archive and removes
           the archive;
        3. reads the train / t10k idx files and saves each
           ``(images, labels)`` tuple with ``torch.save`` as
           ``self.training_file`` and ``self.test_file``.

        Raises:
            OSError: if a folder cannot be created for a reason other than
                already existing.
        """
        import errno
        import gzip
        import shutil

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        processed_dir = os.path.join(self.root, self.processed_folder)

        # Create each folder in its own try block: with a single shared try,
        # an already existing raw folder would skip creating the processed one.
        for folder in (raw_dir, processed_dir):
            try:
                os.makedirs(folder)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        # download files
        for url in self.urls:
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            download_url(url, root=raw_dir, filename=filename, md5=None)
            if not file_path.endswith('.gz'):
                # nothing to decompress; keep the download as it is
                continue
            # strip only the trailing suffix so a '.gz' elsewhere in the path
            # (e.g. in a directory name) is left alone
            unpacked_path = file_path[:-len('.gz')]
            with gzip.GzipFile(file_path) as zip_f, \
                    open(unpacked_path, 'wb') as out_f:
                shutil.copyfileobj(zip_f, out_f)
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(raw_dir, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(raw_dir, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(raw_dir, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(raw_dir, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(processed_dir, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(processed_dir, self.test_file), 'wb') as f:
            torch.save(test_set, f)
        print('Done!')
Пример #5
0
 def process(self):
     """Read one MNIST split from ``self.root`` and return it.

     ``self.train`` selects the files: the ``train-*`` pair when truthy,
     the ``t10k-*`` pair otherwise.

     Returns:
         Tuple ``(read_image_file(...), read_label_file(...))`` for the
         selected split.
     """
     print('Processing...')

     # pick the idx file names for the requested split
     if self.train:
         images_name = 'train-images-idx3-ubyte'
         labels_name = 'train-labels-idx1-ubyte'
     else:
         images_name = 't10k-images-idx3-ubyte'
         labels_name = 't10k-labels-idx1-ubyte'

     images = read_image_file(os.path.join(self.root, images_name))
     labels = read_label_file(os.path.join(self.root, labels_name))
     print('Done!')
     return (images, labels)
Пример #6
0
def patched_download(self):
    """wget patched download method.

    Returns immediately when ``self._check_exists()`` is truthy. Otherwise
    creates ``self.raw_folder`` and ``self.processed_folder``, downloads each
    url in ``self.resources`` into the raw folder with ``wget``, extracts the
    archive in place with ``extract_archive`` (the archive is kept), then
    reads the train / t10k idx files and saves each ``(images, labels)``
    tuple with ``torch.save`` as ``self.training_file`` / ``self.test_file``.

    The md5 stored with each resource is not verified here.

    Raises:
        Whatever ``sp.run(..., check=True)`` raises when ``wget`` exits with a
        non-zero status (``CalledProcessError`` if ``sp`` is the standard
        ``subprocess`` module — confirm against the file's imports).
    """
    if self._check_exists():
        return

    os.makedirs(self.raw_folder, exist_ok=True)
    os.makedirs(self.processed_folder, exist_ok=True)

    # archives are downloaded to and extracted in the same directory
    download_root = os.path.expanduser(self.raw_folder)

    # download files
    for url, _md5 in self.resources:
        filename = url.rpartition("/")[2]

        # Use wget to download archives; check=True stops here on a failed
        # download instead of failing later on a missing or partial archive.
        sp.run(["wget", url, "-P", download_root], check=True)

        archive = os.path.join(download_root, filename)
        print("Extracting {} to {}".format(archive, download_root))
        extract_archive(archive, download_root, False)

    # process and save as torch files
    print("Processing...")

    training_set = (
        read_image_file(os.path.join(self.raw_folder, "train-images-idx3-ubyte")),
        read_label_file(os.path.join(self.raw_folder, "train-labels-idx1-ubyte")),
    )
    test_set = (
        read_image_file(os.path.join(self.raw_folder, "t10k-images-idx3-ubyte")),
        read_label_file(os.path.join(self.raw_folder, "t10k-labels-idx1-ubyte")),
    )
    with open(os.path.join(self.processed_folder, self.training_file), "wb") as f:
        torch.save(training_set, f)
    with open(os.path.join(self.processed_folder, self.test_file), "wb") as f:
        torch.save(test_set, f)

    print("Done!")
Пример #7
0
def main():
    """Export the raw MNIST idx files as JPEG images plus index files.

    Reads the train and t10k splits from ``mnist/MNIST/raw/``, prints the
    size of each image set, then writes every image to ``<root>/train/`` or
    ``<root>/test/`` and records one ``"<image path> <label>"`` line per
    image in ``train.txt`` / ``test.txt`` under the same root.
    """
    root = "mnist/MNIST/raw/"

    train_set = (mnist.read_image_file(
        os.path.join(root, 'train-images-idx3-ubyte')),
                 mnist.read_label_file(
                     os.path.join(root, 'train-labels-idx1-ubyte')))

    test_set = (mnist.read_image_file(
        os.path.join(root, 't10k-images-idx3-ubyte')),
                mnist.read_label_file(
                    os.path.join(root, 't10k-labels-idx1-ubyte')))

    print("train set:", train_set[0].size())
    print("test set:", test_set[0].size())

    def convert_to_img(train=True):
        """Write one split as images and an index file.

        Args:
            train: export the training split when truthy, else the test split.
        """
        split = 'train' if train else 'test'
        images, labels = train_set if train else test_set

        # same string layout as before, so the paths written to the index
        # file are unchanged
        data_path = root + '/' + split + '/'
        if not os.path.exists(data_path):
            os.makedirs(data_path)

        # the with-block closes the index file even if saving an image fails
        with open(root + split + '.txt', 'w') as f:
            for i, (img, label) in enumerate(zip(images, labels)):
                img_path = data_path + str(i) + '.jpg'
                io.imsave(img_path, img.numpy())
                # int() yields the plain label value without depending on how
                # the label object happens to be printed
                f.write(img_path + ' ' + str(int(label)) + '\n')

    convert_to_img(True)
    convert_to_img(False)
Пример #8
0
def init_private_dataset(args):
    """Convert the raw private dataset files into images.

    Only ``args.private_dataset == 'FEMNIST'`` is handled: the
    ``emnist-letters-*`` idx files are read from ``Data/FEMNIST/raw/``, the
    size of each image set is printed, and ``convert_to_img`` is called once
    for the training split and once for the test split with
    ``Data/FEMNIST/`` as the root.

    Args:
        args: object exposing a ``private_dataset`` attribute.

    Raises:
        ValueError: if ``args.private_dataset`` is not ``'FEMNIST'``. Without
            this check the function failed later on unbound local paths.
    """
    if args.private_dataset != 'FEMNIST':
        raise ValueError(
            "Unsupported private dataset: {!r} (only 'FEMNIST' is handled)"
            .format(args.private_dataset))

    rootraw = "Data/FEMNIST/raw/"
    root = "Data/FEMNIST/"

    train_set = (
        mnist.read_image_file(os.path.join(rootraw, 'emnist-letters-train-images-idx3-ubyte')),
        mnist.read_label_file(os.path.join(rootraw, 'emnist-letters-train-labels-idx1-ubyte'))
    )

    test_set = (
        mnist.read_image_file(os.path.join(rootraw, 'emnist-letters-test-images-idx3-ubyte')),
        mnist.read_label_file(os.path.join(rootraw, 'emnist-letters-test-labels-idx1-ubyte'))
    )

    print("train set:", train_set[0].size())
    print("test set:", test_set[0].size())

    convert_to_img(root, train_set, test_set, train=True)
    convert_to_img(root, train_set, test_set, train=False)
Пример #9
0
def fashionDatasetPrep(root):
    """Export the idx files under ``root`` as JPEG images plus CSV indexes.

    For each of the ``train`` and ``eval`` splits, every image is written to
    ``<root>/<mode>/<i>.jpg`` and one ``"<image path> <label>"`` line per
    image is written to ``<root>/<mode>.csv``.

    Args:
        root: directory containing the ``train-*`` and ``t10k-*`` idx files;
            the image folders and CSV files are created inside it.
    """
    train_set = (mnist.read_image_file(
        os.path.join(root, 'train-images-idx3-ubyte')),
                 mnist.read_label_file(
                     os.path.join(root, 'train-labels-idx1-ubyte')))
    eval_set = (mnist.read_image_file(
        os.path.join(root, 't10k-images-idx3-ubyte')),
                mnist.read_label_file(
                    os.path.join(root, 't10k-labels-idx1-ubyte')))
    print("train_set :", train_set[0].size())
    print("eval_set :", eval_set[0].size())

    # Pair each mode with its own data; iterating eval_set for both modes
    # filled the train folder and train.csv with evaluation samples.
    splits = (('train', train_set), ('eval', eval_set))
    for mode, (images, labels) in splits:
        data_path = os.path.join(root, mode)
        with open(os.path.join(root, mode + '.csv'), 'w') as f:
            if not os.path.exists(data_path):
                os.makedirs(data_path)
            for i, (img, label) in enumerate(zip(images, labels)):
                img_path = os.path.join(data_path, str(i) + '.jpg')
                io.imsave(img_path, img.numpy())
                f.write(img_path + ' ' + str(label.item()) + '\n')
Пример #10
0
def data_loader(root_path):
    """Read both MNIST splits stored below ``<root_path>/train``.

    The idx files are looked up in the ``train`` and ``test`` subfolders of
    ``<root_path>/train``.

    Args:
        root_path: parent directory of the ``train`` folder.

    Returns:
        ``{'train': {'data', 'label'}, 'test': {'data', 'label'}}`` holding
        the results of ``read_image_file`` / ``read_label_file``.
    """
    base_dir = os.path.join(root_path, 'train')

    def locate(split, file_name):
        return os.path.join(base_dir, split, file_name)

    train_images = read_image_file(locate('train', 'train-images-idx3-ubyte'))
    train_labels = read_label_file(locate('train', 'train-labels-idx1-ubyte'))
    test_images = read_image_file(locate('test', 't10k-images-idx3-ubyte'))
    test_labels = read_label_file(locate('test', 't10k-labels-idx1-ubyte'))

    return {
        'train': {'data': train_images, 'label': train_labels},
        'test': {'data': test_images, 'label': test_labels},
    }
Пример #11
0
    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already.

        Returns immediately when ``self._check_exists()`` is truthy. Otherwise
        both folders are created with ``makedir_exist_ok``, every
        ``(url, md5)`` pair in ``self.resources`` is passed to
        ``download_and_extract_archive`` with ``self.raw_folder`` as the
        download root, and the train / t10k idx files are read and saved with
        ``torch.save`` as ``self.training_file`` and ``self.test_file``.
        """
        if self._check_exists():
            return

        for directory in (self.raw_folder, self.processed_folder):
            makedir_exist_ok(directory)

        # download files
        for resource_url, checksum in self.resources:
            archive_name = resource_url.rpartition('/')[2]
            download_and_extract_archive(
                resource_url,
                download_root=self.raw_folder,
                filename=archive_name,
                md5=checksum,
            )

        # process and save as torch files
        print('Processing...')

        def raw_path(name):
            return os.path.join(self.raw_folder, name)

        train_images = read_image_file(raw_path('train-images-idx3-ubyte'))
        train_labels = read_label_file(raw_path('train-labels-idx1-ubyte'))
        eval_images = read_image_file(raw_path('t10k-images-idx3-ubyte'))
        eval_labels = read_label_file(raw_path('t10k-labels-idx1-ubyte'))

        train_target = os.path.join(self.processed_folder, self.training_file)
        with open(train_target, 'wb') as handle:
            torch.save((train_images, train_labels), handle)

        eval_target = os.path.join(self.processed_folder, self.test_file)
        with open(eval_target, 'wb') as handle:
            torch.save((eval_images, eval_labels), handle)

        print('Done!')
Пример #12
0
def mnist_to_image(path, train=True):
    '''
    Save one MNIST split as JPEG images and write a matching index txt file.

    path: directory holding the dataset's binary idx files. Images go to
          <path>/mnist_train/ or <path>/mnist_test/, the index file to
          <path>/dir_file/train.txt or <path>/dir_file/test.txt.
    train: export the training split when truthy, otherwise the test split.

    Each index line is "<image path> <label>". Both splits are read and
    their image-set sizes printed regardless of which one is exported.
    '''
    train_set = (mnist.read_image_file(
        os.path.join(path, 'train-images-idx3-ubyte')),
                 mnist.read_label_file(
                     os.path.join(path, 'train-labels-idx1-ubyte')))
    test_set = (mnist.read_image_file(
        os.path.join(path, 't10k-images-idx3-ubyte')),
                mnist.read_label_file(
                    os.path.join(path, 't10k-labels-idx1-ubyte')))
    print("training set :", train_set[0].size())
    print("test set :", test_set[0].size())

    split = 'train' if train else 'test'
    images, labels = train_set if train else test_set

    # os.path.join works with or without a trailing separator on `path`;
    # plain string concatenation silently required one.
    index_dir = os.path.join(path, 'dir_file')
    index_file = os.path.join(path, 'dir_file/' + split + '.txt')
    data_path = os.path.join(path, 'mnist_' + split + '/')

    # the index directory was never created, so open() failed on a fresh tree
    for directory in (index_dir, data_path):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # the with-block closes the index file even if saving an image fails
    with open(index_file, 'w') as f:
        for i, (img, label) in enumerate(zip(images, labels)):
            img_path = data_path + str(i) + '.jpg'
            skimage.io.imsave(img_path, img.numpy())
            f.write(img_path + ' ' + str(label.numpy()) + '\n')
    def __init__(self):
        """---------------------------------------------------------------------
        Desc.:   Class Constructor. Passes [28 * 28, 10 * 10, 10] to
                 NeuralNetwork.__init__ (presumably the layer sizes - confirm
                 against NeuralNetwork), prints a banner and loads the MNIST
                 training images and labels from ./data/raw/ into
                 self.train_images and self.target_val.
        Args:    -
        Returns: -
        ---------------------------------------------------------------------"""

        NeuralNetwork.__init__(self, [28 * 28, 10 * 10, 10])

        # Single-argument print(...) calls produce the same output on
        # Python 2 and also parse on Python 3, where the bare print
        # statement is a syntax error.
        print("----------------------------------------------------------------")
        print("Digit Classfier using pytorch nn module")
        print("Author: Ankit Manerikar")
        print("Written on: 09-21-2017")
        print("----------------------------------------------------------------")
        print("Loading MNIST Dataset ...")
        self.train_images = read_image_file(
            './data/raw/train-images-idx3-ubyte')
        self.target_val = read_label_file('./data/raw/train-labels-idx1-ubyte')
        print("Dataset Loaded")
        print("\nClass initialized")
#!/usr/bin/env python3
# -*-coding:utf-8 -*-
import os
from skimage import io
import torchvision.datasets.mnist as mnist
import numpy
 
 
 
 
# Training split as a (read_image_file result, read_label_file result) tuple,
# read from the current working directory.
train_set = (
    mnist.read_image_file('./train-images.idx3-ubyte'),
    mnist.read_label_file('./train-labels.idx1-ubyte')
)
 
# Test split, same tuple layout.
test_set = (
    mnist.read_image_file('./t10k-images.idx3-ubyte'),
    mnist.read_label_file('./t10k-labels.idx1-ubyte')
)
 
# Report the size of the first element (the image data) of each split.
print("train set:", train_set[0].size())
print("test set:", test_set[0].size())
 
 
def convert_to_img(train=True):
    if(train):
        # f = open(root + 'train.txt', 'w')
        # data_path = './train/'
        #if(not os.path.exists(data_path)):
        #    os.makedirs(data_path)
        for i, (img, label) in enumerate(zip(train_set[0], train_set[1])):
Пример #15
0
import os, skimage, cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets.mnist as mnist

# NOTE(review): this hands the literal string "OPENCV_PYTHON" to the shell as
# a command; its purpose is not evident from this file - confirm it is needed.
os.system("OPENCV_PYTHON")

# The idx files are read from the "datum" folder next to this script.
path_file = os.path.dirname(os.path.realpath(__file__))
root = os.path.join(path_file, "datum")

# train data: (read_image_file result, read_label_file result)
train_set = (mnist.read_image_file(
    os.path.join(root, "train-images-idx3-ubyte")),
             mnist.read_label_file(
                 os.path.join(root, "train-labels-idx1-ubyte")))

# test data: same tuple layout as train_set
test_set = (mnist.read_image_file(os.path.join(root,
                                               "t10k-images-idx3-ubyte")),
            mnist.read_label_file(os.path.join(root,
                                               "t10k-labels-idx1-ubyte")))

# report the size of the first element (the image data) of each split
print(">>> train set: ", train_set[0].size())
print(">>> test set: ", test_set[0].size())


def convert_to_image(train=True):
    if train:
        f = open(os.path.join(root, "train.txt"), "w")
        image_path = os.path.join(root, "train")
Пример #16
0
    Args:
        root (string): Root directory of dataset where ``MNIST/processed/training.pt``
            and  ``MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
import os
from skimage import io
import torchvision.datasets.mnist as mnist
import struct
# Directory that holds the extracted idx files.
root = "../"

# Training split as (images, labels). The image entry used to be a second
# read of the label file, so train_set[0] reported labels instead of images.
# NOTE(review): the image file name follows the dotted naming of the other
# three files read here - confirm it matches the file on disk.
train_set = (
    mnist.read_image_file(os.path.join(root, 'train-images.idx3-ubyte')),
    mnist.read_label_file(os.path.join(root, 'train-labels.idx1-ubyte')))

# Test split, same (images, labels) layout.
test_set = (mnist.read_image_file(os.path.join(root,
                                               't10k-images.idx3-ubyte')),
            mnist.read_label_file(os.path.join(root,
                                               't10k-labels.idx1-ubyte')))

# Report the size of the image data of each split.
print("training set :", train_set[0].size())
print("test set :", test_set[0].size())
Пример #17
0
    def download(self):
        """Download the EMNIST data if it doesn't exist in processed_folder already.

        Returns immediately when ``self._check_exists()`` is truthy. Otherwise:

        1. creates ``<root>/<raw_folder>`` and ``<root>/<processed_folder>``;
        2. downloads ``self.url`` into the raw folder, extracts the zip
           archive there and deletes it;
        3. gunzips every ``*.gz`` file found in the extracted ``gzip``
           subfolder into the raw folder, then removes that subfolder;
        4. for every split in ``self.splits`` reads the
           ``emnist-<split>-{train,test}-*`` idx files and saves each
           ``(images, labels)`` tuple with ``torch.save`` under
           ``self._training_file(split)`` / ``self._test_file(split)``.

        Raises:
            OSError: if a folder cannot be created for a reason other than
                already existing.
        """
        import contextlib
        import errno
        import gzip
        import shutil
        import zipfile

        if self._check_exists():
            return

        raw_folder = os.path.join(self.root, self.raw_folder)
        processed_folder = os.path.join(self.root, self.processed_folder)

        # Create each folder in its own try block: with a single shared try,
        # an already existing raw folder would skip creating the processed one.
        for folder in (raw_folder, processed_folder):
            try:
                os.makedirs(folder)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        # download files
        # six is only needed for the download itself, so it is imported here
        from six.moves import urllib

        print('Downloading ' + self.url)
        filename = self.url.rpartition('/')[2]
        file_path = os.path.join(raw_folder, filename)
        # stream to disk and close the response instead of holding the whole
        # archive in memory
        with contextlib.closing(urllib.request.urlopen(self.url)) as data, \
                open(file_path, 'wb') as f:
            shutil.copyfileobj(data, f)

        print('Extracting zip archive')
        with zipfile.ZipFile(file_path) as zip_f:
            zip_f.extractall(raw_folder)
        os.unlink(file_path)

        gzip_folder = os.path.join(raw_folder, 'gzip')
        for gzip_file in os.listdir(gzip_folder):
            if not gzip_file.endswith('.gz'):
                continue
            print('Extracting ' + gzip_file)
            # drop only the trailing '.gz' suffix
            target = os.path.join(raw_folder, gzip_file[:-len('.gz')])
            with gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f, \
                    open(target, 'wb') as out_f:
                shutil.copyfileobj(zip_f, out_f)
        shutil.rmtree(gzip_folder)

        # process and save as torch files
        for split in self.splits:
            print('Processing ' + split)
            training_set = (
                read_image_file(
                    os.path.join(
                        raw_folder,
                        'emnist-{}-train-images-idx3-ubyte'.format(split))),
                read_label_file(
                    os.path.join(
                        raw_folder,
                        'emnist-{}-train-labels-idx1-ubyte'.format(split))))
            test_set = (
                read_image_file(
                    os.path.join(
                        raw_folder,
                        'emnist-{}-test-images-idx3-ubyte'.format(split))),
                read_label_file(
                    os.path.join(
                        raw_folder,
                        'emnist-{}-test-labels-idx1-ubyte'.format(split))))
            with open(
                    os.path.join(processed_folder,
                                 self._training_file(split)), 'wb') as f:
                torch.save(training_set, f)
            with open(
                    os.path.join(processed_folder,
                                 self._test_file(split)), 'wb') as f:
                torch.save(test_set, f)

        print('Done!')
 def get_test_set(self):
     """Read the test split stored in ``self.root``.

     Returns:
         Tuple ``(mnist.read_image_file(...), mnist.read_label_file(...))``
         for the ``test-images-idx3-ubyte`` / ``test-labels-idx1-ubyte``
         files.
     """
     images_path = os.path.join(self.root, 'test-images-idx3-ubyte')
     images = mnist.read_image_file(images_path)

     labels_path = os.path.join(self.root, 'test-labels-idx1-ubyte')
     labels = mnist.read_label_file(labels_path)

     return (images, labels)
def test_to_superpixels():
    """Check the ``ToSLIC`` transform on MNIST samples and batches.

    Downloads and extracts every entry of ``resources`` into a scratch
    ``MNIST/raw`` folder below a randomly named ``/tmp`` directory, stores the
    t10k split as both ``training.pt`` and ``test.pt`` so that
    ``MNIST(root, download=False)`` finds processed files, then asserts the
    attributes and sizes produced by ``ToSLIC()`` and by
    ``ToSLIC(add_seg=True, add_img=True)``, for a single sample and for a
    batch of two.
    """
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))

    raw_folder = osp.join(root, 'MNIST', 'raw')
    processed_folder = osp.join(root, 'MNIST', 'processed')

    makedirs(raw_folder)
    # Everything after the first directory exists runs under try/finally so
    # the scratch directory is removed even when an assertion fails.
    try:
        makedirs(processed_folder)
        for resource in resources:
            path = download_url(resource, raw_folder)
            # raw_folder already starts with root, so it is used directly
            extract_gz(path, raw_folder)

        test_set = (
            read_image_file(osp.join(raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(osp.join(raw_folder, 't10k-labels-idx1-ubyte')),
        )

        torch.save(test_set, osp.join(processed_folder, 'training.pt'))
        torch.save(test_set, osp.join(processed_folder, 'test.pt'))

        dataset = MNIST(root, download=False)

        # plain transform: positions and features only
        dataset.transform = T.Compose([T.ToTensor(), ToSLIC()])

        data, y = dataset[0]
        assert len(data) == 2
        assert data.pos.dim() == 2 and data.pos.size(1) == 2
        assert data.x.dim() == 2 and data.x.size(1) == 1
        assert data.pos.size(0) == data.x.size(0)
        assert y == 7

        loader = DataLoader(dataset, batch_size=2, shuffle=False)
        for data, y in loader:
            assert len(data) == 4
            assert data.pos.dim() == 2 and data.pos.size(1) == 2
            assert data.x.dim() == 2 and data.x.size(1) == 1
            assert data.batch.dim() == 1
            assert data.ptr.dim() == 1
            assert data.pos.size(0) == data.x.size(0) == data.batch.size(0)
            assert y.tolist() == [7, 2]
            break

        # transform that additionally keeps the segmentation and the image
        dataset.transform = T.Compose(
            [T.ToTensor(), ToSLIC(add_seg=True, add_img=True)])

        data, y = dataset[0]
        assert len(data) == 4
        assert data.pos.dim() == 2 and data.pos.size(1) == 2
        assert data.x.dim() == 2 and data.x.size(1) == 1
        assert data.pos.size(0) == data.x.size(0)
        assert data.seg.size() == (1, 28, 28)
        assert data.img.size() == (1, 1, 28, 28)
        assert data.seg.max().item() + 1 == data.x.size(0)
        assert y == 7

        loader = DataLoader(dataset, batch_size=2, shuffle=False)
        for data, y in loader:
            assert len(data) == 6
            assert data.pos.dim() == 2 and data.pos.size(1) == 2
            assert data.x.dim() == 2 and data.x.size(1) == 1
            assert data.batch.dim() == 1
            assert data.ptr.dim() == 1
            assert data.pos.size(0) == data.x.size(0) == data.batch.size(0)
            assert data.seg.size() == (2, 28, 28)
            assert data.img.size() == (2, 1, 28, 28)
            assert y.tolist() == [7, 2]
            break
    finally:
        shutil.rmtree(root)
import os
from skimage import io
import torchvision.datasets.mnist as mnist
import numpy

# Source directory of the raw idx files and destination directory for the
# exported output.
readFrom = 'mnist/MNIST/raw/'
writeTo = '../mnistImgs/'
# Create the destination directory on first use.
if(not os.path.exists(writeTo)):
    os.makedirs(writeTo)

# Training split as a (read_image_file result, read_label_file result) tuple.
train_set = (
    mnist.read_image_file(os.path.join(readFrom, 'train-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(readFrom, 'train-labels-idx1-ubyte'))
)

# Test split, same tuple layout.
test_set = (
    mnist.read_image_file(os.path.join(readFrom,'t10k-images-idx3-ubyte')),
    mnist.read_label_file(os.path.join(readFrom,'t10k-labels-idx1-ubyte'))
)
 
# Report the size of the first element (the image data) of each split.
print("train set:", train_set[0].size())
print("test set:", test_set[0].size())

def convert_to_img(train=True):
    if(train):
        f = open(writeTo + 'train.txt', 'w')
        data_path = writeTo + 'train/'
        if(not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img, label) in enumerate(zip(train_set[0], train_set[1])):
            img_path = data_path + str(i) + '.jpg'
def main():
    """Prepare the MNIST files if needed, then train and evaluate the model.

    Command-line arguments come from ``get_parser()``. When either processed
    torch file is missing below ``args.training_dir``, every file in
    ``args.data_dir`` is gunzipped into ``args.training_dir`` and the train /
    t10k idx files are converted into the processed training and test files.
    Afterwards a ``Net`` model is trained with SGD for ``args.epochs`` epochs,
    evaluated after each epoch, and ``visualize_stn`` is called on the test
    loader.
    """
    parser = get_parser()
    args = parser.parse_args()
    ######################################################################
    # Loading the data
    # ----------------
    #
    # In this post we experiment with the classic MNIST dataset. Using a
    # standard convolutional network augmented with a spatial transformer
    # network.
    data_dir = args.data_dir
    training_dir = args.training_dir

    processed_dir = path.join(training_dir, mnist.MNIST.processed_folder)
    training_path = path.join(processed_dir, mnist.MNIST.training_file)
    test_path = path.join(processed_dir, mnist.MNIST.test_file)

    # Process unless BOTH files are already there. The test-file half of this
    # check used to be `not path.join(...)`, which is always False for a
    # non-empty path, so a missing test file went unnoticed.
    if not (path.exists(training_path) and path.exists(test_path)):
        # process and save as torch files
        LOG.info('Processing dataset...')

        for name in os.listdir(data_dir):
            full_path = path.join(data_dir, name)
            save_path = path.join(training_dir, name.replace('.gz', ''))
            with open(save_path,
                      'wb') as out_f, gzip.GzipFile(full_path) as zip_f:
                out_f.write(zip_f.read())

        training_set = (mnist.read_image_file(
            path.join(training_dir, 'train-images-idx3-ubyte')),
                        mnist.read_label_file(
                            path.join(training_dir,
                                      'train-labels-idx1-ubyte')))
        test_set = (mnist.read_image_file(
            path.join(training_dir, 't10k-images-idx3-ubyte')),
                    mnist.read_label_file(
                        path.join(training_dir, 't10k-labels-idx1-ubyte')))

        # The folder may already exist when only one processed file is
        # missing; creating it unconditionally would raise in that case.
        if not path.isdir(processed_dir):
            os.makedirs(processed_dir)
        with open(training_path, 'wb') as f:
            torch.save(training_set, f)
        with open(test_path, 'wb') as f:
            torch.save(test_set, f)

        LOG.info('Dataset processing done!')

    # Training dataset
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        root=training_dir,
        train=True,
        download=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4)
    # Test dataset
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        root=training_dir,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=4)

    model = Net()
    if use_cuda:
        model.cuda()

    ######################################################################
    # Training the model
    # ------------------
    #
    # Now, let's use the SGD algorithm to train the model. The network is
    # learning the classification task in a supervised way. In the same time
    # the model is learning STN automatically in an end-to-end fashion.

    optimizer = optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(1, args.epochs + 1):
        train(epoch,
              train_loader,
              model,
              optimizer,
              use_mlboard=not args.skip_mlboard)
        test(test_loader, model, use_mlboard=not args.skip_mlboard)

    # Visualize the STN transformation on some input batch
    visualize_stn(test_loader, model, args.training_dir)
Пример #22
0
                             mnist.MNIST.test_file)):
        # process and save as torch files
        LOG.info('Processing dataset...')

        files = os.listdir(data_dir)
        for file in files:
            full_path = path.join(data_dir, file)
            save_path = path.join(training_dir, file.replace('.gz', ''))
            with open(save_path,
                      'wb') as out_f, gzip.GzipFile(full_path) as zip_f:
                out_f.write(zip_f.read())

        training_set = (mnist.read_image_file(
            path.join(training_dir, 'train-images-idx3-ubyte')),
                        mnist.read_label_file(
                            path.join(training_dir,
                                      'train-labels-idx1-ubyte')))
        test_set = (mnist.read_image_file(
            path.join(training_dir, 't10k-images-idx3-ubyte')),
                    mnist.read_label_file(
                        path.join(training_dir, 't10k-labels-idx1-ubyte')))
        os.makedirs(path.join(training_dir, mnist.MNIST.processed_folder))
        with open(
                path.join(training_dir, mnist.MNIST.processed_folder,
                          mnist.MNIST.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(
                path.join(training_dir, mnist.MNIST.processed_folder,
                          mnist.MNIST.test_file), 'wb') as f:
            torch.save(test_set, f)