Example #1
File: mnist.py Project: zhoujian1210/fuel
class MNIST(H5PYDataset):
    u"""MNIST dataset.

    MNIST (Modified National Institute of Standards and Technology) [LBBH] is
    a database of handwritten digits. It is one of the most famous
    datasets in machine learning and consists of 60,000 training images
    and 10,000 testing images. The images are grayscale and 28 x 28 pixels
    large. It is accessible through Yann LeCun's website [LECUN].

    .. [LBBH] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner,
       *Gradient-based learning applied to document recognition*,
       Proceedings of the IEEE, November 1998, 86(11):2278-2324.

    .. [LECUN] http://yann.lecun.com/exdb/mnist/

    Parameters
    ----------
    which_sets : tuple of str
        Which split to load. Valid values are 'train' and 'test',
        corresponding to the training set (60,000 examples) and the test
        set (10,000 examples).

    """
    filename = 'mnist.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features', ))

    def __init__(self, which_sets, **kwargs):
        kwargs.setdefault('load_in_memory', True)
        super(MNIST,
              self).__init__(file_or_path=find_in_data_path(self.filename),
                             which_sets=which_sets,
                             **kwargs)
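A minimal usage sketch (not part of the original example), assuming Fuel is installed and mnist.hdf5 has been built into Fuel's data path:

from fuel.schemes import ShuffledScheme
from fuel.streams import DataStream

train_set = MNIST(('train',))  # the class above; load_in_memory defaults to True
stream = DataStream.default_stream(
    dataset=train_set,
    iteration_scheme=ShuffledScheme(train_set.num_examples, batch_size=128))
# default_transformers has rescaled the uint8 pixels to floatX in [0, 1]
features, targets = next(stream.get_epoch_iterator())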
Example #2
class CIFAR10(H5PYDataset):
    """The CIFAR10 dataset of natural images.

    This dataset is a labeled subset of the ``80 million tiny images``
    dataset [TINY]. It consists of 60,000 32 x 32 colour images in 10
    classes, with 6,000 images per class. There are 50,000 training
    images and 10,000 test images [CIFAR10].

    .. [TINY] Antonio Torralba, Rob Fergus and William T. Freeman,
       *80 million tiny images: a large dataset for non-parametric
       object and scene recognition*, Pattern Analysis and Machine
       Intelligence, IEEE Transactions on 30.11 (2008): 1958-1970.

    .. [CIFAR10] Alex Krizhevsky, *Learning Multiple Layers of Features
       from Tiny Images*, technical report, 2009.

    Parameters
    ----------
    which_sets : tuple of str
        Which split to load. Valid values are 'train' and 'test',
        corresponding to the training set (50,000 examples) and the test
        set (10,000 examples). Note that CIFAR10 does not have a
        validation set; usually you will create your own
        training/validation split using the `subset` argument.

    """
    filename = 'cifar10.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features',))

    def __init__(self, which_sets, **kwargs):
        kwargs.setdefault('load_in_memory', True)
        super(CIFAR10, self).__init__(
            file_or_path=find_in_data_path(self.filename),
            which_sets=which_sets, **kwargs)
Example #3
class CIFAR10(H5PYDataset):
    """The CIFAR10 dataset of natural images.

    This dataset is a labeled subset of the ``80 million tiny images``
    dataset [TINY]. It consists of 60,000 32 x 32 colour images in 10
    classes, with 6,000 images per class. There are 50,000 training
    images and 10,000 test images [CIFAR10].

    .. [TINY] Antonio Torralba, Rob Fergus and William T. Freeman,
       *80 million tiny images: a large dataset for non-parametric
       object and scene recognition*, Pattern Analysis and Machine
       Intelligence, IEEE Transactions on 30.11 (2008): 1958-1970.

    .. [CIFAR10] Alex Krizhevsky, *Learning Multiple Layers of Features
       from Tiny Images*, technical report, 2009.

    Parameters
    ----------
    which_sets : tuple of str
        Which split to load. Valid values are 'train' and 'test',
        corresponding to the training set (50,000 examples) and the test
        set (10,000 examples). Note that CIFAR10 does not have a
        validation set; usually you will create your own
        training/validation split using the `subset` argument.

    """
    filename = 'cifar10.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features',))

    def __init__(self, which_sets, **kwargs):
        kwargs.setdefault('load_in_memory', True)
        super(CIFAR10, self).__init__(
            file_or_path=find_in_data_path(self.filename),
            which_sets=which_sets, **kwargs)
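The training/validation split mentioned in the docstring can be made with the `subset` argument inherited from H5PYDataset; a hedged sketch, assuming cifar10.hdf5 is present in Fuel's data path:

# Carve a validation set out of the 50,000 training examples.
train = CIFAR10(('train',), subset=slice(0, 45000))
valid = CIFAR10(('train',), subset=slice(45000, 50000))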
Example #4
class Camvid(H5PYDataset):
    '''The CamVid motion-based segmentation dataset.

    The Cambridge-driving Labeled Video Database (CamVid) [Camvid1]_ provides
    high-quality videos acquired at 30 Hz with the corresponding
    semantically labeled masks at 1 Hz and, in part, 15 Hz. The ground
    truth labels associate each pixel with one of 32 semantic classes.

    This loader is intended for the SegNet version of the CamVid dataset,
    which resizes the original data to 360 by 480 resolution and remaps
    the ground truth to a subset of 11 semantic classes, plus a void
    class. The dataset should be downloaded from [Camvid2]_.

    Parameters
    ----------
    which_sets : tuple of str
        Which split to load. Valid values are 'train', 'valid' and 'test'.

    References
    ----------
    .. [Camvid1] http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/
    .. [Camvid2]
       https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid
    '''

    filename = 'camvid.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features', ))

    def __init__(self, which_sets, **kwargs):
        super(Camvid,
              self).__init__(file_or_path=find_in_data_path(self.filename),
                             which_sets=which_sets,
                             **kwargs)
Example #5
class LFW(H5PYDataset):
    u"""LFW dataset.

    Labeled Faces in the Wild dataset.

    Labeled Faces in the Wild is a database of face photographs
    designed for studying the problem of unconstrained face recognition.

    http://vis-www.cs.umass.edu/lfw/

    Parameters
    ----------
    which_sets : tuple of str
        Which split to load. Valid values are 'train' and 'test'.

    """
    url_dir = "https://archive.org/download/lfw_fuel/"
    filename = 'lfw.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features', ))

    def __init__(self, which_sets, **kwargs):
        kwargs.setdefault('load_in_memory', True)
        super(LFW,
              self).__init__(file_or_path=find_in_data_path(self.filename),
                             which_sets=which_sets,
                             **kwargs)
Example #6
def get_dataset_iterator(dataset,
                         split,
                         include_features=True,
                         include_targets=False,
                         unit_scale=True):
    """Get iterator for dataset, split, targets (labels) and scaling (from 255 to 1.0)"""
    sources = []
    sources = sources + ['features'] if include_features else sources
    sources = sources + ['targets'] if include_targets else sources
    if split == "all":
        splits = ('train', 'valid', 'test')
    elif split == "nontrain":
        splits = ('valid', 'test')
    else:
        splits = (split, )

    dataset_fname = find_in_data_path("{}.hdf5".format(dataset))
    datastream = H5PYDataset(dataset_fname, which_sets=splits, sources=sources)
    if unit_scale:
        datastream.default_transformers = uint8_pixels_to_floatX(
            ('features', ))

    train_stream = DataStream.default_stream(
        dataset=datastream,
        iteration_scheme=SequentialExampleScheme(datastream.num_examples))

    it = train_stream.get_epoch_iterator()
    return it
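A hypothetical call, assuming an mnist.hdf5 file in Fuel's data path:

it = get_dataset_iterator('mnist', 'train', include_targets=True)
# SequentialExampleScheme yields one example at a time
features, targets = next(it)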
Example #7
File: mnist.py Project: xiaoyexixi/fuel
class MNIST(H5PYDataset):
    u"""MNIST dataset.

    MNIST (Modified National Institute of Standards and Technology) [LBBH] is
    a database of handwritten digits. It is one of the most famous
    datasets in machine learning and consists of 60,000 training images
    and 10,000 testing images. The images are grayscale and 28 x 28 pixels
    large. It is accessible through Yann LeCun's website [LECUN].

    .. [LBBH] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner,
       *Gradient-based learning applied to document recognition*,
       Proceedings of the IEEE, November 1998, 86(11):2278-2324.

    .. [LECUN] http://yann.lecun.com/exdb/mnist/

    Parameters
    ----------
    which_set : 'train' or 'test'
        Whether to load the training set (60,000 samples) or the test set
        (10,000 samples).

    """
    filename = 'mnist.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features', ))

    def __init__(self, which_set, **kwargs):
        kwargs.setdefault('load_in_memory', True)
        super(MNIST, self).__init__(self.data_path, which_set, **kwargs)

    @property
    def data_path(self):
        return os.path.join(config.data_path, self.filename)
Example #8
class MMImdbDataset(H5PYDataset):

    filename = 'multimodal_imdb.hdf5'
    default_transformers = ((Padding, [], {
        'mask_sources': ('sequences', )
    }), (SequenceTransposer, [], {
        'which_sources': ('sequences', 'sequences_mask')
    })) + uint8_pixels_to_floatX(('images', ))

    def __init__(self, which_sets, **kwargs):

        #kwargs.setdefault('file_or_path', MMImdbDataset.get_filepath())
        kwargs.setdefault('sources', ('features', 'genres'))
        super(MMImdbDataset, self).__init__(which_sets=which_sets, **kwargs)

    @staticmethod
    def get_filepath(filename=None):
        if filename is None:
            filename = MMImdbDataset.filename
        return find_in_data_path(filename)

    def create_stream(self, batch_size=None):
        if batch_size is None:
            batch_size = self.num_examples
        return DataStream.default_stream(dataset=self,
                                         iteration_scheme=ShuffledScheme(
                                             examples=self.num_examples,
                                             batch_size=batch_size))

    @staticmethod
    def get_target_names(filename=None):
        if filename is None:
            filename = MMImdbDataset.get_filepath()
        with h5py.File(filename, 'r') as f:
            target_names = json.loads(f['genres'].attrs['target_names'])
        return target_names
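Because the `file_or_path` default is commented out above, a caller supplies it explicitly; a hedged sketch, assuming multimodal_imdb.hdf5 is in Fuel's data path:

dataset = MMImdbDataset(which_sets=('train',),
                        file_or_path=MMImdbDataset.get_filepath())
stream = dataset.create_stream(batch_size=64)  # ShuffledScheme batches
genre_names = MMImdbDataset.get_target_names()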
Example #9
class Flower(H5PYDataset):
    filename = 'flowers102_32x32.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features',))

    def __init__(self, which_sets, **kwargs):
        kwargs.setdefault('load_in_memory', True)
        super(Flower, self).__init__(
            file_or_path=find_in_data_path(self.filename),
            which_sets=which_sets, **kwargs)
Example #10
class JOS(H5PYDataset):

    filename = 'jos.hdf5'

    default_transformers = uint8_pixels_to_floatX(('features', ))

    def __init__(self, which_sets, **kwargs):
        kwargs.setdefault('load_in_memory', False)
        super(JOS,
              self).__init__(file_or_path=find_in_data_path(self.filename),
                             which_sets=which_sets,
                             **kwargs)
Example #11
def get_all_data_inorder(filename, batch_size):
    sources = ('features', 'targets')

    dataset_fname = find_in_data_path(filename+'.hdf5')
    data_all = H5PYDataset(dataset_fname, which_sets=['train', 'valid', 'test'],
                           sources=sources)
    data_all.default_transformers = uint8_pixels_to_floatX(('features',))
    main_stream = DataStream.default_stream(
        dataset=data_all,
        iteration_scheme=SequentialScheme(data_all.num_examples, batch_size))
    color_stream = Colorize(main_stream, which_sources=('features',))
    return data_all.num_examples, color_stream
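A hypothetical call, assuming a Fuel-format mnist.hdf5 with train/valid/test splits and the project's Colorize transformer in scope:

num_examples, stream = get_all_data_inorder('mnist', batch_size=100)
for features, targets in stream.get_epoch_iterator():
    pass  # batches arrive sequentially across train, valid and test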
Example #12
class LORIS(H5PYDataset):
    """The LORIS dataset of brain images in the trial.
    """
    filename = 'Loris_data.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features', ))

    def __init__(self, which_sets, **kwargs):
        kwargs.setdefault('load_in_memory', True)
        super(LORIS,
              self).__init__(file_or_path=find_in_data_path(self.filename),
                             which_sets=which_sets,
                             **kwargs)
Example #13
def get_all_data_inorder(filename, batch_size):
    sources = ('features', 'targets')

    dataset_fname = find_in_data_path(filename + '.hdf5')
    data_all = H5PYDataset(dataset_fname,
                           which_sets=['train', 'valid', 'test'],
                           sources=sources)
    data_all.default_transformers = uint8_pixels_to_floatX(('features', ))
    main_stream = DataStream.default_stream(dataset=data_all,
                                            iteration_scheme=SequentialScheme(
                                                data_all.num_examples,
                                                batch_size))
    color_stream = Colorize(main_stream, which_sources=('features', ))
    return data_all.num_examples, color_stream
Example #14
def get_dataset_iterator(dataset,
                         split,
                         include_features=True,
                         include_targets=False,
                         unit_scale=True,
                         label_transforms=False,
                         return_length=False):
    """Get iterator for dataset, split, targets (labels) and scaling (from 255 to 1.0)"""
    sources = []
    sources = sources + ['features'] if include_features else sources
    sources = sources + ['targets'] if include_targets else sources
    if split == "all":
        splits = ('train', 'valid', 'test')
    elif split == "nontrain":
        splits = ('valid', 'test')
    else:
        splits = (split, )

    dataset_fname = find_in_data_path("{}.hdf5".format(dataset))
    h5_dataset = H5PYDataset(dataset_fname, which_sets=splits, sources=sources)
    if unit_scale:
        h5_dataset.default_transformers = uint8_pixels_to_floatX(
            ('features', ))

    datastream = DataStream.default_stream(
        dataset=h5_dataset,
        iteration_scheme=SequentialExampleScheme(h5_dataset.num_examples))

    if label_transforms:
        # TODO: maybe refactor this common bit with get_custom_streams below
        datastream = AddLabelUncertainty(datastream,
                                         chance=0,
                                         which_sources=('targets', ))

        datastream = RandomLabelStrip(datastream,
                                      chance=0,
                                      which_sources=('targets', ))

        # HACK: allow variable stretch
        datastream = StretchLabels(datastream,
                                   length=128,
                                   which_sources=('targets', ))

    it = datastream.get_epoch_iterator()
    if return_length:
        return it, h5_dataset.num_examples
    else:
        return it
Example #15
class SVHN(H5PYDataset):
    """The Street View House Numbers (SVHN) dataset.

    SVHN [SVHN] is a real-world image dataset for developing machine
    learning and object recognition algorithms with minimal requirement
    on data preprocessing and formatting. It can be seen as similar in
    flavor to MNIST [LBBH] (e.g., the images are of small cropped
    digits), but incorporates an order of magnitude more labeled data
    (over 600,000 digit images) and comes from a significantly harder,
    unsolved, real world problem (recognizing digits and numbers in
    natural scene images). SVHN is obtained from house numbers in
    Google Street View images.

    .. [SVHN] Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco,
       Bo Wu, Andrew Y. Ng. *Reading Digits in Natural Images with
       Unsupervised Feature Learning*, NIPS Workshop on Deep Learning
       and Unsupervised Feature Learning, 2011.

    .. [LBBH] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner,
       *Gradient-based learning applied to document recognition*,
       Proceedings of the IEEE, November 1998, 86(11):2278-2324.

    Parameters
    ----------
    which_format : {1, 2}
        SVHN format 1 contains the full numbers, whereas SVHN format 2
        contains cropped digits.
    which_sets : tuple of str
        Which split to load. Valid values are 'train', 'test' and 'extra',
        corresponding to the training set (73,257 examples), the test
        set (26,032 examples) and the extra set (531,131 examples).
        Note that SVHN does not have a validation set; usually you will
        create your own training/validation split using the `subset`
        argument.

    """
    _filename = 'svhn_format_{}.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features',))

    def __init__(self, which_format, which_sets, **kwargs):
        self.which_format = which_format
        super(SVHN, self).__init__(
            file_or_path=find_in_data_path(self.filename),
            which_sets=which_sets, **kwargs)

    @property
    def filename(self):
        return self._filename.format(self.which_format)
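A minimal sketch of the two-argument constructor above, assuming svhn_format_2.hdf5 has been built in Fuel's data path:

# format 2 selects the cropped-digit variant; the filename property
# expands to 'svhn_format_2.hdf5'
svhn_test = SVHN(which_format=2, which_sets=('test',))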
Example #16
class TinyILSVRC2012(H5PYDataset):
    """The Tiny ILSVRC2012 Dataset.
    Parameters
    ----------
    which_sets : tuple of str
        Which split to load. Valid values are 'train' (1,281,167 examples)
        'valid' (50,000 examples), and 'test' (100,000 examples).
    """
    filename = 'ilsvrc2012_tiny.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features',))

    def __init__(self, which_sets, **kwargs):
        kwargs.setdefault('load_in_memory', False)
        super(TinyILSVRC2012, self).__init__(
            file_or_path=find_in_data_path(self.filename),
            which_sets=which_sets, **kwargs)
Example #17
File: celeba.py Project: zhoujian1210/fuel
class CelebA(H5PYDataset):
    """The CelebFaces Attributes Dataset (CelebA) dataset.

    CelebA is a large-scale face
    attributes dataset with more than 200K celebrity images, each
    with 40 attribute annotations. The images in this dataset cover
    large pose variations and background clutter. CelebA has large
    diversities, large quantities, and rich annotations, including:

    * 10,177 identities
    * 202,599 face images
    * 5 landmark locations per image
    * 40 binary attribute annotations per image.

    The dataset can be employed as the training and test sets for
    the following computer vision tasks:

    * face attribute recognition
    * face detection
    * landmark (or facial part) localization

    Parameters
    ----------
    which_format : {'aligned_cropped', '64'}
        Either the aligned and cropped version of CelebA, or
        a 64x64 version of it.
    which_sets : tuple of str
        Which split to load. Valid values are 'train', 'valid' and
        'test' corresponding to the training set (162,770 examples), the
        validation set (19,867 examples) and the test set (19,962
        examples).

    """
    _filename = 'celeba_{}.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features', ))

    def __init__(self, which_format, which_sets, **kwargs):
        self.which_format = which_format
        super(CelebA,
              self).__init__(file_or_path=find_in_data_path(self.filename),
                             which_sets=which_sets,
                             **kwargs)

    @property
    def filename(self):
        return self._filename.format(self.which_format)
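A hedged sketch of the constructor above, assuming celeba_64.hdf5 is present in Fuel's data path:

# '64' selects the 64x64 variant; the filename property expands
# to 'celeba_64.hdf5'
celeba_train = CelebA('64', ('train',), sources=('features',))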
Example #18
File: fuel_helper.py Project: dribnet/plat
def get_dataset_iterator(dataset,
                         split,
                         include_features=True,
                         include_targets=False,
                         unit_scale=True,
                         label_transforms=False,
                         return_length=False):
    """Get an epoch iterator over `dataset` for `split`, optionally
    including targets (labels) and scaling pixels from [0, 255] to [0.0, 1.0].
    """
    sources = []
    sources = sources + ['features'] if include_features else sources
    sources = sources + ['targets'] if include_targets else sources
    if split == "all":
        splits = ('train', 'valid', 'test')
    elif split == "nontrain":
        splits = ('valid', 'test')
    else:
        splits = (split,)

    dataset_fname = find_in_data_path("{}.hdf5".format(dataset))
    h5_dataset = H5PYDataset(dataset_fname, which_sets=splits,
                             sources=sources)
    if unit_scale:
        h5_dataset.default_transformers = uint8_pixels_to_floatX(('features',))

    datastream = DataStream.default_stream(
        dataset=h5_dataset,
        iteration_scheme=SequentialExampleScheme(h5_dataset.num_examples))

    if label_transforms:
        # TODO: maybe refactor this common bit with get_custom_streams below
        datastream = AddLabelUncertainty(datastream,
                                         chance=0,
                                         which_sources=('targets',))

        datastream = RandomLabelStrip(datastream,
                                      chance=0,
                                      which_sources=('targets',))

        # HACK: allow variable stretch
        datastream = StretchLabels(datastream,
                                   length=128,
                                   which_sources=('targets',))

    it = datastream.get_epoch_iterator()
    if return_length:
        return it, h5_dataset.num_examples
    else:
        return it
Example #19
File: svhn.py Project: zhoujian1210/fuel
class SVHN(H5PYDataset):
    """The Street View House Numbers (SVHN) dataset.

    SVHN [SVHN] is a real-world image dataset for developing machine
    learning and object recognition algorithms with minimal requirement
    on data preprocessing and formatting. It can be seen as similar in
    flavor to MNIST [LBBH] (e.g., the images are of small cropped
    digits), but incorporates an order of magnitude more labeled data
    (over 600,000 digit images) and comes from a significantly harder,
    unsolved, real world problem (recognizing digits and numbers in
    natural scene images). SVHN is obtained from house numbers in
    Google Street View images.

    .. [SVHN] Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco,
       Bo Wu, Andrew Y. Ng. *Reading Digits in Natural Images with
       Unsupervised Feature Learning*, NIPS Workshop on Deep Learning
       and Unsupervised Feature Learning, 2011.

    .. [LBBH] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner,
       *Gradient-based learning applied to document recognition*,
       Proceedings of the IEEE, November 1998, 86(11):2278-2324.

    Parameters
    ----------
    which_format : {1, 2}
        SVHN format 1 contains the full numbers, whereas SVHN format 2
        contains cropped digits.
    which_sets : tuple of str
        Which split to load. Valid values are 'train', 'test' and 'extra',
        corresponding to the training set (73,257 examples), the test
        set (26,032 examples) and the extra set (531,131 examples).
        Note that SVHN does not have a validation set; usually you will
        create your own training/validation split using the `subset`
        argument.

    """
    _filename = 'svhn_format_{}.hdf5'
    default_transformers = uint8_pixels_to_floatX(('features', ))

    def __init__(self, which_format, which_sets, **kwargs):
        self.which_format = which_format
        super(SVHN,
              self).__init__(file_or_path=find_in_data_path(self.filename),
                             which_sets=which_sets,
                             **kwargs)

    @property
    def filename(self):
        return self._filename.format(self.which_format)
Example #20
class IAM_ONDB(H5PYDataset):
    """The iam_ondb of online images.

    Parameters
    ----------
    which_sets : tuple of str
        Which split to load. Valid values are 'train' and 'test'.
    Notes
    -----
    Users can create their own
    training / validation split using the `subset` argument.

    """
    filename = 'iam_ondb.hdf5'

    default_transformers = uint8_pixels_to_floatX(('image_features', ))

    def __init__(self, which_sets, **kwargs):
        super(IAM_ONDB,
              self).__init__(file_or_path=find_in_data_path(self.filename),
                             which_sets=which_sets,
                             **kwargs)
Example #21
def get_dataset_iterator(dataset,
                         split,
                         include_features=True,
                         include_targets=False,
                         unit_scale=True):
    """Get an epoch iterator over `dataset` for `split`, optionally
    including targets (labels) and scaling pixels from [0, 255] to [0.0, 1.0].
    """
    sources = []
    sources = sources + ['features'] if include_features else sources
    sources = sources + ['targets'] if include_targets else sources
    if split == "all":
        splits = ('train', 'valid', 'test')
    elif split == "nontrain":
        splits = ('valid', 'test')
    else:
        splits = (split,)

    dataset_fname = find_in_data_path("{}.hdf5".format(dataset))
    datastream = H5PYDataset(dataset_fname, which_sets=splits,
                             sources=sources)
    if unit_scale:
        datastream.default_transformers = uint8_pixels_to_floatX(('features',))

    train_stream = DataStream.default_stream(
        dataset=datastream,
        iteration_scheme=SequentialExampleScheme(datastream.num_examples))

    it = train_stream.get_epoch_iterator()
    return it
Example #22
File: fuel_helper.py Project: dribnet/plat
def create_custom_streams(filename, training_batch_size, monitoring_batch_size,
                          include_targets=False, color_convert=False,
                          allowed=None, stretch=None, random_spread=False,
                          random_label_strip=False, add_label_uncertainty=False,
                          uuid_str=None,
                          split_names=['train', 'valid', 'test']):
    """Creates data streams from fuel hdf5 file.

    Currently features must be 64x64.

    Parameters
    ----------
    filename : string
        basename to hdf5 file for input
    training_batch_size : int
        Batch size for training.
    monitoring_batch_size : int
        Batch size for monitoring.
    include_targets : bool
        If ``True``, use both features and targets. If ``False``, use
        features only.
    color_convert : bool
        If ``True``, input is assumed to be one-channel, and so will
        be transformed to three-channel by duplication.

    Returns
    -------
    rval : tuple of data streams
        Data streams for the main loop, the training set monitor,
        the validation set monitor and the test set monitor.

    """
    sources = ('features', 'targets') if include_targets else ('features',)

    dataset_fname = find_in_data_path(filename+'.hdf5')
    data_train = H5PYDataset(dataset_fname, which_sets=[split_names[0]],
                             sources=sources)
    data_valid = H5PYDataset(dataset_fname, which_sets=[split_names[1]],
                             sources=sources)
    data_test = H5PYDataset(dataset_fname, which_sets=[split_names[2]],
                            sources=sources)
    data_train.default_transformers = uint8_pixels_to_floatX(('features',))
    data_valid.default_transformers = uint8_pixels_to_floatX(('features',))
    data_test.default_transformers = uint8_pixels_to_floatX(('features',))

    results = create_streams(data_train, data_valid, data_test,
                             training_batch_size, monitoring_batch_size)

    if color_convert:
        results = tuple(map(
                    lambda s: Colorize(s, which_sources=('features',)),
                    results))

    if add_label_uncertainty:
        results = tuple(map(
                    lambda s: AddLabelUncertainty(s, chance=add_label_uncertainty,
                                       which_sources=('targets',)),
                    results))

    if random_label_strip:
        results = tuple(map(
                    lambda s: RandomLabelStrip(s, chance=random_label_strip,
                                       which_sources=('targets',)),
                    results))

    # wrap labels in stretcher if requested
    if stretch is not None:
        results = tuple(map(
                    lambda s: StretchLabels(s, which_sources=('targets',), length=stretch),
                    results))

    # wrap labels in scrubber if not all labels are allowed
    if allowed:
        results = tuple(map(
                    lambda s: Scrubber(s, allowed=allowed,
                                       which_sources=('targets',)),
                    results))

    if random_spread:
        results = tuple(map(
                    lambda s: RandomLabelOptionalSpreader(s,
                                       which_sources=('targets',)),
                    results))

    if uuid_str is not None:
        results = tuple(map(
                    lambda s: UUIDStretch(s, uuid_str=uuid_str,
                                       which_sources=('targets',)),
                    results))

    return results
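A hypothetical call, assuming a Fuel-format celeba.hdf5 with 64x64 features and the project's create_streams helper in scope:

main_stream, train_monitor, valid_monitor, test_monitor = create_custom_streams(
    'celeba', training_batch_size=100, monitoring_batch_size=500,
    include_targets=True)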
Example #23
def create_custom_streams(filename,
                          training_batch_size,
                          monitoring_batch_size,
                          include_targets=False,
                          color_convert=False,
                          allowed=None,
                          stretch=False,
                          split_names=['train', 'valid', 'test']):
    """Creates data streams from fuel hdf5 file.

    Currently features must be 64x64.

    Parameters
    ----------
    filename : string
        basename to hdf5 file for input
    training_batch_size : int
        Batch size for training.
    monitoring_batch_size : int
        Batch size for monitoring.
    include_targets : bool
        If ``True``, use both features and targets. If ``False``, use
        features only.
    color_convert : bool
        If ``True``, input is assumed to be one-channel, and so will
        be transformed to three-channel by duplication.

    Returns
    -------
    rval : tuple of data streams
        Data streams for the main loop, the training set monitor,
        the validation set monitor and the test set monitor.

    """
    sources = ('features', 'targets') if include_targets else ('features', )

    dataset_fname = find_in_data_path(filename + '.hdf5')
    data_train = H5PYDataset(dataset_fname,
                             which_sets=[split_names[0]],
                             sources=sources)
    data_valid = H5PYDataset(dataset_fname,
                             which_sets=[split_names[1]],
                             sources=sources)
    data_test = H5PYDataset(dataset_fname,
                            which_sets=[split_names[2]],
                            sources=sources)
    data_train.default_transformers = uint8_pixels_to_floatX(('features', ))
    data_valid.default_transformers = uint8_pixels_to_floatX(('features', ))
    data_test.default_transformers = uint8_pixels_to_floatX(('features', ))

    results = create_streams(data_train, data_valid, data_test,
                             training_batch_size, monitoring_batch_size)

    if color_convert:
        results = tuple(
            map(lambda s: Colorize(s, which_sources=('features', )), results))

    # wrap labels in stretcher if requested
    if stretch:
        results = tuple(
            map(lambda s: StretchLabels(s, which_sources=('targets', )),
                results))

    # wrap labels in scrubber if not all labels are allowed
    if allowed:
        results = tuple(
            map(
                lambda s: Scrubber(
                    s, allowed=allowed, which_sources=('targets', )), results))

    return results