Example 1
def test_iterate_scheme():
    from collections import OrderedDict

    import numpy
    from fuel.datasets import IndexableDataset
    from fuel.schemes import (SequentialScheme, ShuffledScheme,
                              SequentialExampleScheme, ShuffledExampleScheme)

    seed = 1234
    rng = numpy.random.RandomState(seed)
    features = rng.randint(256, size=(8, 2, 2))
    targets = rng.randint(4, size=(8, 1))

    dataset = IndexableDataset(indexables=OrderedDict([('features', features),
                                                       ('targets', targets)]),
                               axis_labels=OrderedDict([('features', ('batch', 'height', 'width')),
                                                        ('targets', ('batch', 'index'))]))

    schemes = [SequentialScheme(examples=8, batch_size=5),
               ShuffledScheme(examples=8, batch_size=3),
               SequentialExampleScheme(examples=8),
               ShuffledExampleScheme(examples=8)]

    # for scheme in schemes:
    #     print(list(scheme.get_request_iterator()))

    state = dataset.open()
    scheme = ShuffledScheme(examples=dataset.num_examples, batch_size=3)

    for request in scheme.get_request_iterator():
        data = dataset.get_data(state=state, request=request)
        print(data[0].shape, data[1].shape)

    dataset.close(state)
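For reference, the batch schemes above yield lists of indices while the example schemes yield single indices. A quick sketch (the shuffled orderings depend on Fuel's default RNG, so they will vary from run to run):

from fuel.schemes import (SequentialScheme, ShuffledScheme,
                          SequentialExampleScheme, ShuffledExampleScheme)

# Batch schemes yield index lists, one per minibatch (the last may be short).
print(list(SequentialScheme(examples=8, batch_size=5).get_request_iterator()))
# [[0, 1, 2, 3, 4], [5, 6, 7]]
print(list(ShuffledScheme(examples=8, batch_size=3).get_request_iterator()))

# Example schemes yield one index at a time.
print(list(SequentialExampleScheme(examples=8).get_request_iterator()))
# [0, 1, 2, 3, 4, 5, 6, 7]
print(list(ShuffledExampleScheme(examples=8).get_request_iterator()))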
Example 2
def get_datastream(path, which_set='train_si84', batch_size=1):
    wsj_dataset = H5PYDataset(path, which_sets=(which_set, ))
    print(path, which_set)
    iterator_scheme = ShuffledScheme(batch_size=batch_size, examples=wsj_dataset.num_examples)
    base_stream = DataStream(dataset=wsj_dataset,
                             iteration_scheme=iterator_scheme)
    fs = FilterSources(data_stream=base_stream, sources=['features', 'ivectors', 'targets'])
    padded_stream = Padding(data_stream=fs)
    return padded_stream
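A hypothetical call to the helper above (the HDF5 path is a placeholder, and the source names are assumed to match the WSJ file it expects). Padding adds a '<source>_mask' entry for every padded source:

stream = get_datastream('/path/to/wsj.h5', which_set='train_si84', batch_size=8)
for batch in stream.get_epoch_iterator(as_dict=True):
    print(sorted(batch.keys()))  # features, ivectors, targets plus their masks
    break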
Example 3
def create_streams(train_set, valid_set, test_set, training_batch_size,
                   monitoring_batch_size):
    """Creates data streams for training and monitoring.

    Parameters
    ----------
    train_set : :class:`fuel.datasets.Dataset`
        Training set.
    valid_set : :class:`fuel.datasets.Dataset`
        Validation set.
    test_set : :class:`fuel.datasets.Dataset`
        Test set.
    training_batch_size : int
        Batch size for training.
    monitoring_batch_size : int
        Batch size for monitoring.

    Returns
    -------
    rval : tuple of data streams
        Data streams for the main loop, the training set monitor,
        the validation set monitor and the test set monitor.

    """
    main_loop_stream = DataStream.default_stream(
        dataset=train_set,
        iteration_scheme=ShuffledScheme(train_set.num_examples,
                                        training_batch_size))
    train_monitor_stream = DataStream.default_stream(
        dataset=train_set,
        iteration_scheme=ShuffledScheme(train_set.num_examples,
                                        monitoring_batch_size))
    valid_monitor_stream = DataStream.default_stream(
        dataset=valid_set,
        iteration_scheme=ShuffledScheme(valid_set.num_examples,
                                        monitoring_batch_size))
    test_monitor_stream = DataStream.default_stream(
        dataset=test_set,
        iteration_scheme=ShuffledScheme(test_set.num_examples,
                                        monitoring_batch_size))

    return (main_loop_stream, train_monitor_stream, valid_monitor_stream,
            test_monitor_stream)
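A minimal usage sketch for create_streams, using small in-memory IndexableDataset stand-ins for the three splits (the shapes and names here are arbitrary):

from collections import OrderedDict

import numpy
from fuel.datasets import IndexableDataset


def _toy_set(n):
    rng = numpy.random.RandomState(0)
    return IndexableDataset(OrderedDict([
        ('features', rng.rand(n, 4).astype('float32')),
        ('targets', rng.randint(2, size=(n, 1)))]))


streams = create_streams(_toy_set(100), _toy_set(20), _toy_set(20),
                         training_batch_size=10, monitoring_batch_size=5)
main_loop_stream = streams[0]
print(next(main_loop_stream.get_epoch_iterator())[0].shape)  # (10, 4)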
Example 4
def get_mnist_streams(num_train_examples, batch_size):
    from fuel.datasets import MNIST
    dataset = MNIST(("train", ))
    all_ind = numpy.arange(dataset.num_examples)
    rng = numpy.random.RandomState(seed=1)
    rng.shuffle(all_ind)

    indices_train = all_ind[:num_train_examples]
    indices_valid = all_ind[num_train_examples:]

    train_stream = Flatten(DataStream.default_stream(
        dataset, iteration_scheme=ShuffledScheme(indices_train, batch_size)),
                           which_sources=('features', ))

    valid_stream = Flatten(DataStream.default_stream(
        dataset, iteration_scheme=ShuffledScheme(indices_valid, batch_size)),
                           which_sources=('features', ))

    return train_stream, valid_stream
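Note that ShuffledScheme accepts an explicit index array as well as an integer, which is how the train/valid split above is carved out of a single MNIST dataset. A self-contained sketch of the same idea:

from collections import OrderedDict

import numpy
from fuel.datasets import IndexableDataset
from fuel.schemes import ShuffledScheme
from fuel.streams import DataStream

features = numpy.arange(20).reshape(10, 2)
dataset = IndexableDataset(OrderedDict([('features', features)]))
valid_indices = numpy.array([7, 8, 9])
stream = DataStream(dataset,
                    iteration_scheme=ShuffledScheme(valid_indices, batch_size=2))
for (batch,) in stream.get_epoch_iterator():
    print(batch)  # only rows 7-9, in shuffled order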
Example 5
def create_cifar10_data_streams(batch_size, monitoring_batch_size, rng=None):
    train_set = CIFAR10(('train', ),
                        sources=('features', ),
                        subset=slice(0, 45000))
    valid_set = CIFAR10(('train', ),
                        sources=('features', ),
                        subset=slice(45000, 50000))
    main_loop_stream = DataStream.default_stream(
        train_set,
        iteration_scheme=ShuffledScheme(train_set.num_examples,
                                        batch_size,
                                        rng=rng))
    train_monitor_stream = DataStream.default_stream(
        train_set,
        iteration_scheme=ShuffledScheme(5000, monitoring_batch_size, rng=rng))
    valid_monitor_stream = DataStream.default_stream(
        valid_set,
        iteration_scheme=ShuffledScheme(5000, monitoring_batch_size, rng=rng))
    return main_loop_stream, train_monitor_stream, valid_monitor_stream
Example 6
    def get_datastream(self, kind, indices):
        split = {
            'trn': self.trn,
            'val': self.val,
            'tst': self.tst,
        }[kind]
        indices = indices if indices is not None else split.ind
        assert len(set(indices) - set(split.ind)) == 0, \
            'requested indices outside of split'
        ds = DataStream.default_stream(
            split.set, iteration_scheme=ShuffledScheme(indices, split.batch_size))
        return ds
Example 7
def create_flower_data_streams(
        batch_size, monitoring_batch_size, rng=None, random=True):
    # Since it's so small just use the entire dataset.
    train_set = Flower(
        ('train',), sources=('features',), subset=slice(0, 114646))
    valid_set = Flower(
        ('train',), sources=('features',), subset=slice(114646, 122835))
    main_loop_stream = DataStream.default_stream(
        train_set,
        iteration_scheme=ShuffledScheme(
            train_set.num_examples, batch_size, rng=rng))
    train_monitor_stream = DataStream.default_stream(
        train_set,
        iteration_scheme=ShuffledScheme(
            5000, monitoring_batch_size, rng=rng))
    valid_monitor_stream = DataStream.default_stream(
        valid_set,
        iteration_scheme=ShuffledScheme(
            5000, monitoring_batch_size, rng=rng))
    return main_loop_stream, train_monitor_stream, valid_monitor_stream
Example 8
    def get_curr_out(self, model, example_set, batch_size=256):
        scheme = ShuffledScheme(examples=example_set.num_examples,
                                batch_size=batch_size)
        example_state = example_set.open()
        x = T.matrix('features')
        out = model.apply(x)
        pred_fn = theano.function([x], out)
        y = np.zeros((example_set.num_examples))
        y_hat = np.zeros((example_set.num_examples))
        for idx, request in enumerate(scheme.get_request_iterator()):
            data = example_set.get_data(state=example_state, request=request)
            out_val = pred_fn(data[0].astype(np.float32))
            start_idx = idx * batch_size
            end_idx = start_idx + len(request)
            y[start_idx:end_idx] = data[1].flatten()
            y_hat[start_idx:end_idx] = out_val.flatten()
        example_set.close(example_state)
        P = np.sum(y)
        N = np.sum(1 - y)

        return (y, y_hat), P, N
Example 9
def get_cmv_v2_64_len10_streams(batch_size):
    path = '/data/lisatmp3/cooijmat/datasets/cmv/cmv20x64x64_png.hdf5'
    train_dataset = CMVv2(path=path, which_set="train")
    valid_dataset = CMVv2(path=path, which_set="valid")
    train_ind = numpy.arange(train_dataset.num_examples)
    valid_ind = numpy.arange(valid_dataset.num_examples)
    rng = numpy.random.RandomState(seed=1)
    rng.shuffle(train_ind)
    rng.shuffle(valid_ind)

    train_datastream = DataStream.default_stream(
        train_dataset, iteration_scheme=ShuffledScheme(train_ind, batch_size))
    train_datastream = Preprocessor_CMV_v2(train_datastream, 10)

    valid_datastream = DataStream.default_stream(
        valid_dataset, iteration_scheme=ShuffledScheme(valid_ind, batch_size))
    valid_datastream = Preprocessor_CMV_v2(valid_datastream, 10)

    train_datastream.sources = ('features', 'targets')
    valid_datastream.sources = ('features', 'targets')

    return train_datastream, valid_datastream
Example 10
def _liacl_data_stream(dataset,
                       rel2index,
                       batch_size,
                       word2index,
                       target='negative_sampling',
                       name="",
                       k=3,
                       shuffle=False,
                       neg_sample_kwargs={}):
    batches_per_epoch = int(np.ceil(dataset.num_examples / float(batch_size)))
    if shuffle:
        iteration_scheme = ShuffledScheme(dataset.num_examples, batch_size)
    else:
        iteration_scheme = SequentialScheme(dataset.num_examples, batch_size)
    data_stream = DataStream(dataset, iteration_scheme=iteration_scheme)
    data_stream = NumberizeWords(data_stream,
                                 word2index,
                                 default=word2index[UNKNOWN_TOKEN],
                                 which_sources=('head', 'tail'))
    data_stream = NumberizeWords(data_stream, rel2index, which_sources=('rel',))

    if target == "score":
        data_stream = Rename(data_stream, {'score': 'target'})
    else:
        data_stream = FilterSources(data_stream,
                                    sources=('head', 'tail', 'rel'))

    data_stream = Padding(data_stream,
                          mask_sources=('head', 'tail'),
                          mask_dtype=np.float32)

    if target == 'negative_sampling':
        logger.info('target for data stream ' + str(name) +
                    ' is negative sampling')
        data_stream = NegativeSampling(data_stream, k=k)
    elif target == 'filtered_negative_sampling':
        logger.info('target for data stream ' + str(name) +
                    ' is filtered negative sampling')
        data_stream = FilteredNegativeSampling(data_stream,
                                               k=k,
                                               **neg_sample_kwargs)
    elif target == 'score':
        logger.info('target for data stream ' + str(name) + ' is score')
    else:
        raise NotImplementedError(
            'target {} must be one of "score", "negative_sampling" or '
            '"filtered_negative_sampling"'.format(target))

    data_stream = MergeSource(data_stream,
                              merge_sources=('head', 'tail', 'head_mask',
                                             'tail_mask', 'rel'),
                              merge_name='input')

    return data_stream, batches_per_epoch
Example 11
def create_packing_gaussian_mixture_data_streams(num_packings,
                                                 batch_size,
                                                 monitoring_batch_size,
                                                 means=None,
                                                 variances=None,
                                                 priors=None,
                                                 rng=None,
                                                 num_examples=100000,
                                                 sources=('features', )):

    train_set = GaussianPackingMixture(num_packings=num_packings,
                                       num_examples=num_examples,
                                       means=means,
                                       variances=variances,
                                       priors=priors,
                                       rng=rng,
                                       sources=sources)

    valid_set = GaussianPackingMixture(num_packings=num_packings,
                                       num_examples=num_examples,
                                       means=means,
                                       variances=variances,
                                       priors=priors,
                                       rng=rng,
                                       sources=sources)

    main_loop_stream = DataStream(train_set,
                                  iteration_scheme=ShuffledScheme(
                                      train_set.num_examples,
                                      batch_size=batch_size,
                                      rng=rng))

    train_monitor_stream = DataStream(train_set,
                                      iteration_scheme=ShuffledScheme(
                                          5000, batch_size, rng=rng))

    valid_monitor_stream = DataStream(valid_set,
                                      iteration_scheme=ShuffledScheme(
                                          5000, batch_size, rng=rng))

    return main_loop_stream, train_monitor_stream, valid_monitor_stream
Example 12
def get_stream(trainXY, batch_size=100):
    #trainXY=genSynXY()
    dataset_train = IndexableDataset(trainXY)
    stream_train_1 = DataStream(dataset=dataset_train,
                                iteration_scheme=ShuffledScheme(
                                    examples=dataset_train.num_examples,
                                    batch_size=batch_size))
    stream_train_2 = Padding(stream_train_1)
    #stream_train_1.sources=('x_mask_o', 'y_mask_o', 'x', 'y')
    stream_train_3 = Mapping(stream_train_2, transpose_stream)

    return (stream_train_3, dataset_train.num_examples)
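The transpose_stream mapping is not defined in this snippet; a plausible stand-in (hypothetical, but consistent with the _transpose helper in the next example) simply swaps the batch and time axes of every padded source:

import numpy as np

def transpose_stream(data):
    # Turn [batch, sequence, ...] arrays into [sequence, batch, ...].
    return tuple(np.swapaxes(array, 0, 1) for array in data)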
Example 13
def get_datastream(dataset, batch_size=160):
    dataset = DataStream(
        dataset,
        iteration_scheme=ShuffledScheme(dataset.num_examples, batch_size),
    )
    dataset = Padding(dataset)

    def _transpose(data):
        return tuple(np.rollaxis(array, 1, 0) for array in data)

    dataset = Mapping(dataset, _transpose)
    return dataset
Example 14
def create_spiral_data_streams(batch_size, monitoring_batch_size, rng=None,
                               num_examples=100000, classes=1, cycles=2,
                               noise=0.1):
    train_set = Spiral(num_examples=num_examples, classes=classes,
                       cycles=cycles, noise=noise, sources=('features',))

    valid_set = Spiral(num_examples=num_examples, classes=classes,
                       cycles=cycles, noise=noise, sources=('features',))

    main_loop_stream = DataStream.default_stream(
        train_set,
        iteration_scheme=ShuffledScheme(
            train_set.num_examples, batch_size=batch_size, rng=rng))

    train_monitor_stream = DataStream.default_stream(
        train_set, iteration_scheme=ShuffledScheme(5000, batch_size, rng=rng))

    valid_monitor_stream = DataStream.default_stream(
        valid_set, iteration_scheme=ShuffledScheme(5000, batch_size, rng=rng))

    return main_loop_stream, train_monitor_stream, valid_monitor_stream
Example 15
    def create_stream(self, batch_size, is_train):
        if is_train:
            iter_scheme = ShuffledScheme(self.num_examples, batch_size)
        else:
            iter_scheme = SequentialScheme(self.num_examples, batch_size)
        stream = DataStream(self, iteration_scheme=iter_scheme)

        # for d in stream.get_epoch_iterator(as_dict=True):
        #     for k, v in d.iteritems():
        #         print('{}:\t{},\t{}'.format(k, v.dtype, v.shape))
        #     break

        return stream
Example 16
def get_stream(hdf5_file, which_set, batch_size=None):
    dataset = H5PYDataset(hdf5_file,
                          which_sets=(which_set, ),
                          load_in_memory=True)
    if batch_size is None:
        batch_size = dataset.num_examples
    stream = DataStream(dataset=dataset,
                        iteration_scheme=ShuffledScheme(
                            examples=dataset.num_examples,
                            batch_size=batch_size))
    # Required because Recurrent bricks receive as input [sequence, batch,
    # features]
    return Mapping(stream, transpose_stream)
Example 17
def launch_data_server(dataset, port, config):
    """
    """
    n_items = dataset.num_examples
    batch_sz = config.hyper_parameters.batch_size
    it_schm = ShuffledScheme(n_items, batch_sz)
    data_stream = DataStream(dataset=dataset, iteration_scheme=it_schm)

    try:
        start_server(data_stream, port=port, hwm=config.data_server.hwm)
    except KeyboardInterrupt as ke:
        print(ke)
    finally:
        data_stream.close()
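On the consuming side, a client would typically connect with fuel.streams.ServerDataStream. The source names, port and produces_examples flag below are assumptions (the exact constructor signature has varied slightly across Fuel versions):

from fuel.streams import ServerDataStream

data_stream = ServerDataStream(('features', 'targets'),
                               produces_examples=False,
                               port=5557)
for batch in data_stream.get_epoch_iterator():
    pass  # train on the batch received from the server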
Example 18
def construct_stream(dataset, rng, batch_size, n_batches=None, **kwargs):
    """Construct data stream.

    Parameters
    ----------
    dataset : Dataset
        Dataset to use.
    rng : numpy.random.RandomState
        Random number generator.
    batch_size : int
        Size of the batch.
    n_batches : int
        Number of batches used to update population statistics.
    """
    if n_batches is not None:
        scheme = ShuffledScheme(n_batches * batch_size, batch_size=batch_size)
    else:
        scheme = ShuffledScheme(dataset.num_examples, batch_size=batch_size)
    stream = DataStream(dataset, iteration_scheme=scheme)
    stream = Mapping(stream, SortMapping(key=key))
    stream = Padding(data_stream=stream, mask_sources=['features', 'phonemes'])
    stream = Transpose(stream, [(1, 0, 2), (1, 0), (1, 0), (1, 0)])
    return stream
Example 19
def make_scheme_and_stream(dset, batchsize, shuffle=True):
    """
    dset is a Fuel `Dataset` and batchsize is an int giving the number of
    examples requested per minibatch; note that we assume we are always
    operating on minibatches (although they can be of size 1).
    """
    if shuffle:
        scheme = ShuffledScheme(examples=dset.num_examples,
                                batch_size=batchsize)
    else:
        scheme = SequentialScheme(examples=dset.num_examples,
                                  batch_size=batchsize)
    data_stream = DataStream(dataset=dset, iteration_scheme=scheme)
    return scheme, data_stream
Example 20
def create_data(data):

    stream = DataStream(data, iteration_scheme=ShuffledScheme(data.num_examples, batch_size))

    # Data Augmentation
    stream = MinimumImageDimensions(stream, image_size, which_sources=('image_features',))
    stream = MaximumImageDimensions(stream, image_size, which_sources=('image_features',))
    stream = RandomHorizontalSwap(stream, which_sources=('image_features',))
    stream = Random2DRotation(stream, which_sources=('image_features',))

    # Data Transformation
    stream = ScaleAndShift(stream, 1./255, 0, which_sources=('image_features',))
    stream = Cast(stream, dtype='float32', which_sources=('image_features',))
    return stream
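The scaling and casting steps can be exercised on their own; a minimal sketch on an in-memory dataset (source name and shapes are arbitrary):

from collections import OrderedDict

import numpy
from fuel.datasets import IndexableDataset
from fuel.schemes import ShuffledScheme
from fuel.streams import DataStream
from fuel.transformers import Cast, ScaleAndShift

images = numpy.random.randint(0, 256, size=(16, 3, 8, 8)).astype('uint8')
dataset = IndexableDataset(OrderedDict([('image_features', images)]))
stream = DataStream(dataset,
                    iteration_scheme=ShuffledScheme(dataset.num_examples, 4))
stream = ScaleAndShift(stream, 1. / 255, 0, which_sources=('image_features',))
stream = Cast(stream, dtype='float32', which_sources=('image_features',))
batch, = next(stream.get_epoch_iterator())
print(batch.dtype, batch.min(), batch.max())  # float32, values in [0, 1]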
Example 21
    def test_batch_iteration_scheme_with_lists(self):
        """Batch schemes should work with more than ndarrays."""
        data = IndexableDataset(
            OrderedDict([('foo', list(range(50))),
                         ('bar', list(range(1, 51)))]))
        stream = DataStream(data,
                            iteration_scheme=ShuffledScheme(
                                data.num_examples, 5))
        returned = [
            sum(batches, [])
            for batches in zip(*list(stream.get_epoch_iterator()))
        ]
        assert set(returned[0]) == set(range(50))
        assert set(returned[1]) == set(range(1, 51))
Example 22
    def get_datastream(dataset, batch_size=160):
        dataset = DataStream(
            dataset,
            iteration_scheme=ShuffledScheme(dataset.num_examples, batch_size),
        )
        dataset = Padding(dataset)

        # if flatten:
        # dataset = Flatten(dataset, which_sources=('features,'))

        def _transpose(data):
            return tuple(np.rollaxis(array, 1, 0) for array in data)

        dataset = Mapping(dataset, _transpose)
        return dataset
Example 23
def get_datastream(path, norm_path, which_set='train_si84', batch_size=1):
    wsj_dataset = H5PYDataset(path, which_sets=(which_set, ))
    data_mean_std = numpy.load(norm_path)

    iterator_scheme = ShuffledScheme(batch_size=batch_size,
                                     examples=wsj_dataset.num_examples)
    base_stream = DataStream(dataset=wsj_dataset,
                             iteration_scheme=iterator_scheme)
    base_stream = Normalize(data_stream=base_stream,
                            means=data_mean_std['mean'],
                            stds=data_mean_std['std'])
    fs = FilterSources(data_stream=base_stream,
                       sources=['features', 'targets'])
    padded_stream = Padding(data_stream=fs)
    return padded_stream
Example 24
def stream_handwriting(which_sets, batch_size, seq_size, tbptt=True):
    dataset = Handwriting(which_sets)
    data_stream = DataStream.default_stream(
        dataset,
        iteration_scheme=ShuffledScheme(
            batch_size * (dataset.num_examples // batch_size), batch_size))
    data_stream = FilterSources(data_stream, sources=('features', ))
    data_stream = Padding(data_stream)
    data_stream = Mapping(data_stream, _transpose)

    if tbptt:
        data_stream = SegmentSequence(data_stream,
                                      add_flag=True,
                                      seq_size=seq_size)

    data_stream = ForceFloatX(data_stream)

    return data_stream
Example 25
    def start_server(port, which_set):
        fuel.server.logger.setLevel('WARN')

        dataset = IMDBText(which_set)
        n_train = dataset.num_examples
        stream = DataStream(
                dataset=dataset,
                iteration_scheme=ShuffledScheme(
                    examples=n_train,
                    batch_size=batch_size)
                )
        print "loading glove"
        glove = GloveTransformer(glove_version, data_stream=stream)
        padded = Padding(
                data_stream=glove,
                mask_sources=('features',)
                )

        fuel.server.start_server(padded, port=port, hwm=20)
Example 26
def get_comb_stream(fea2obj, which_set, batch_size=None, shuffle=True):
    streams = []
    for fea in fea2obj:
        obj = fea2obj[fea]
        dataset = H5PYDataset(obj.fuelfile, which_sets=(which_set,),load_in_memory=True)
        if batch_size is None:
            batch_size = dataset.num_examples
        if shuffle:
            iterschema = ShuffledScheme(examples=dataset.num_examples, batch_size=batch_size)
        else:
            iterschema = SequentialScheme(examples=dataset.num_examples, batch_size=batch_size)
        stream = DataStream(dataset=dataset, iteration_scheme=iterschema)
        if fea in seq_features:
            stream = CutInput(stream, obj.max_len)
            if obj.rec:
                logger.info('transforming data for recursive input')
                # Required because Recurrent bricks receive input as
                # [sequence, batch, features].
                stream = LettersTransposer(stream, which_sources=fea)
        streams.append(stream)
    stream = Merge(streams, tuple(fea2obj.keys()))
    return stream, dataset.num_examples
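For reference, Merge zips several parallel streams and concatenates their data tuples under the given source names. A minimal sketch with two in-memory datasets (names are arbitrary); a sequential or otherwise shared iteration order keeps the rows of the merged streams aligned:

from collections import OrderedDict

import numpy
from fuel.datasets import IndexableDataset
from fuel.schemes import SequentialScheme
from fuel.streams import DataStream
from fuel.transformers import Merge

left = IndexableDataset(OrderedDict([('words', numpy.arange(6).reshape(6, 1))]))
right = IndexableDataset(OrderedDict([('labels', numpy.arange(6) % 2)]))
streams = [DataStream(d, iteration_scheme=SequentialScheme(6, 3))
           for d in (left, right)]
merged = Merge(streams, ('words', 'labels'))
for words, labels in merged.get_epoch_iterator():
    print(words.shape, labels.shape)  # (3, 1) (3,)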
Example 27
    def create_data(data):
        stream = DataStream.default_stream(data,
                                           iteration_scheme=ShuffledScheme(
                                               data.num_examples, batch_size))
        stream_downscale = MinimumImageDimensions(
            stream, image_size, which_sources=('image_features', ))
        #stream_rotate = Random2DRotation(stream_downscale, which_sources=('image_features',))
        stream_max = ScikitResize(stream_downscale,
                                  image_size,
                                  which_sources=('image_features', ))
        stream_scale = ScaleAndShift(stream_max,
                                     1. / 255,
                                     0,
                                     which_sources=('image_features', ))
        stream_cast = Cast(stream_scale,
                           dtype='float32',
                           which_sources=('image_features', ))
        #stream_flat = Flatten(stream_scale, which_sources=('image_features',))
        return stream_cast
Example 28
def make_scheme_and_stream(dset, batchsize, msg_string, shuffle=True):
    """
    dset is a Fuel `DataSet` and batchsize is an int representing the number of
    examples requested per minibatch
    """
    if shuffle:
        print(msg_string +
              " Preparing shuffled datastream for {} examples.".format(
                  dset.num_examples))
        scheme = ShuffledScheme(examples=dset.num_examples,
                                batch_size=batchsize)
    else:
        print(msg_string +
              "Preparing sequential datastream for {} examples.".format(
                  dset.num_examples))
        scheme = SequentialScheme(examples=dset.num_examples,
                                  batch_size=batchsize)
    data_stream = DataStream(dataset=dset, iteration_scheme=scheme)
    return scheme, data_stream
Example 29
def load(batch_size, test_batch_size):
    tr_data = BinarizedMNIST(which_sets=('train', ))
    val_data = BinarizedMNIST(which_sets=('valid', ))
    test_data = BinarizedMNIST(which_sets=('test', ))

    ntrain = tr_data.num_examples
    nval = val_data.num_examples
    ntest = test_data.num_examples

    tr_scheme = ShuffledScheme(examples=ntrain, batch_size=batch_size)
    tr_stream = DataStream(tr_data, iteration_scheme=tr_scheme)

    te_scheme = SequentialScheme(examples=ntest, batch_size=test_batch_size)
    te_stream = DataStream(test_data, iteration_scheme=te_scheme)

    val_scheme = SequentialScheme(examples=nval, batch_size=batch_size)
    val_stream = DataStream(val_data, iteration_scheme=val_scheme)

    return _make_stream(tr_stream, batch_size), \
           _make_stream(val_stream, batch_size), \
           _make_stream(te_stream, test_batch_size)
Example 30
def faces(ntrain=None, nval=None, ntest=None, batch_size=128):
    path = os.path.join(data_dir, 'faces_364293_128px.hdf5')
    tr_data = H5PYDataset(path, which_sets=('train',))
    te_data = H5PYDataset(path, which_sets=('test',))

    if ntrain is None:
        ntrain = tr_data.num_examples
    if ntest is None:
        ntest = te_data.num_examples
    if nval is None:
        nval = te_data.num_examples

    tr_scheme = ShuffledScheme(examples=ntrain, batch_size=batch_size)
    tr_stream = DataStream(tr_data, iteration_scheme=tr_scheme)

    te_scheme = SequentialScheme(examples=ntest, batch_size=batch_size)
    te_stream = DataStream(te_data, iteration_scheme=te_scheme)

    val_scheme = SequentialScheme(examples=nval, batch_size=batch_size)
    val_stream = DataStream(tr_data, iteration_scheme=val_scheme)
    return tr_data, te_data, tr_stream, val_stream, te_stream
Example 31
    def batch_iterator(self, dataset, batchsize, shuffle=False):
        if isinstance(dataset, Dataset):
            if shuffle:
                train_scheme = ShuffledScheme(examples=dataset.num_examples,
                                              batch_size=batchsize)
            else:
                train_scheme = SequentialScheme(examples=dataset.num_examples,
                                                batch_size=batchsize)
            # Use `DataStream.default_stream`, otherwise the default
            # transformers defined by the dataset *won't* be applied
            # (see the sketch after this example).
            stream = DataStream.default_stream(dataset=dataset,
                                               iteration_scheme=train_scheme)
            if self.fuel_stream_xform_fn is not None:
                stream = self.fuel_stream_xform_fn(stream)
            return stream.get_epoch_iterator()
        elif _is_sequence_of_arrays(dataset):
            return iterate_minibatches(dataset, batchsize, shuffle=shuffle)
        else:
            raise TypeError(
                'dataset should be a fuel Dataset instance or a list of arrays'
            )