Example #1
def test_shared_variable_modifier():
    weights = numpy.array([-1, 1], dtype=theano.config.floatX)
    features = [numpy.array(f, dtype=theano.config.floatX) for f in [[1, 2], [3, 4], [5, 6]]]
    targets = [(weights * f).sum() for f in features]
    n_batches = 3
    dataset = IterableDataset(dict(features=features, targets=targets))

    x = tensor.vector("features")
    y = tensor.scalar("targets")
    W = shared_floatx([0, 0], name="W")
    cost = ((x * W).sum() - y) ** 2
    cost.name = "cost"

    step_rule = Scale(0.001)
    sgd = GradientDescent(cost=cost, parameters=[W], step_rule=step_rule)
    main_loop = MainLoop(
        model=None,
        data_stream=dataset.get_example_stream(),
        algorithm=sgd,
        extensions=[
            FinishAfter(after_n_epochs=1),
            SharedVariableModifier(step_rule.learning_rate, lambda n: numpy.cast[theano.config.floatX](10.0 / n)),
        ],
    )

    main_loop.run()

    assert_allclose(step_rule.learning_rate.get_value(), numpy.cast[theano.config.floatX](10.0 / n_batches))
Example #2
def test_ngram_stream_error_on_multiple_sources():
    # Check that NGrams accepts only data streams with one source
    sentences = [list(numpy.random.randint(10, size=sentence_length))
                 for sentence_length in [3, 5, 7]]
    stream = IterableDataset(sentences).get_example_stream()
    stream.sources = ('1', '2')
    assert_raises(ValueError, NGrams, 4, stream)
Example #3
def setup_mainloop(extension):
    """Set up a simple main loop for progress bar tests.

    Create a MainLoop, register the given extension, supply it with a
    DataStream and a minimal model/cost to optimize.

    """
    features = [numpy.array(f, dtype=theano.config.floatX)
                for f in [[1, 2], [3, 4], [5, 6]]]
    dataset = IterableDataset(dict(features=features))

    W = shared_floatx([0, 0], name='W')
    x = tensor.vector('features')
    cost = tensor.sum((x-W)**2)
    cost.name = "cost"

    algorithm = GradientDescent(cost=cost, parameters=[W],
                                step_rule=Scale(1e-3))

    main_loop = MainLoop(
        model=None, data_stream=dataset.get_example_stream(),
        algorithm=algorithm,
        extensions=[
            FinishAfter(after_n_epochs=1),
            extension])

    return main_loop
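
A minimal usage sketch (hypothetical; it assumes Blocks' ProgressBar extension, which these progress-bar tests exercise):

from blocks.extensions import ProgressBar

main_loop = setup_mainloop(ProgressBar())  # register the extension under test
main_loop.run()  # runs the single epoch, rendering the progress bar as it goes
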
Example #4
def test_shared_variable_modifier_two_params():
    weights = numpy.array([-1, 1], dtype=theano.config.floatX)
    features = [numpy.array(f, dtype=theano.config.floatX)
                for f in [[1, 2], [3, 4], [5, 6]]]
    targets = [(weights * f).sum() for f in features]
    n_batches = 3
    dataset = IterableDataset(dict(features=features, targets=targets))

    x = tensor.vector('features')
    y = tensor.scalar('targets')
    W = shared_floatx([0, 0], name='W')
    cost = ((x * W).sum() - y) ** 2
    cost.name = 'cost'

    step_rule = Scale(0.001)
    sgd = GradientDescent(cost=cost, params=[W],
                          step_rule=step_rule)
    modifier = SharedVariableModifier(
        step_rule.learning_rate,
        lambda _, val: numpy.cast[theano.config.floatX](val * 0.2))
    main_loop = MainLoop(
        model=None, data_stream=dataset.get_example_stream(),
        algorithm=sgd,
        extensions=[FinishAfter(after_n_epochs=1), modifier])

    main_loop.run()

    new_value = step_rule.learning_rate.get_value()
    assert_allclose(new_value,
                    0.001 * 0.2 ** n_batches,
                    atol=1e-5)
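
Taken together with Example #1, this test suggests that SharedVariableModifier dispatches on the callback's arity: a one-parameter function receives only the number of iterations done, while a two-parameter function also receives the variable's current value. A sketch of both (names are illustrative):

def schedule_by_step(n):
    # one parameter: called as f(iterations_done), as in Example #1
    return numpy.cast[theano.config.floatX](10.0 / n)

def decay_current_value(n, value):
    # two parameters: called as f(iterations_done, current_value), as here
    return numpy.cast[theano.config.floatX](value * 0.2)
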
Example #5
def setup_mainloop(extension):
    """Set up a simple main loop for progress bar tests.

    Create a MainLoop, register the given extension, supply it with a
    DataStream and a minimal model/cost to optimize.

    """
    # Since progressbar2 3.6.0, the `maxval` kwarg has been replaced by
    # `max_value`, which has a default value of 100. If we're still using
    # `maxval` by accident, this test should fail complaining that
    # the progress bar has received a value out of range.
    features = [numpy.array(f, dtype=theano.config.floatX)
                for f in [[1, 2]] * 101]
    dataset = IterableDataset(dict(features=features))

    W = shared_floatx([0, 0], name='W')
    x = tensor.vector('features')
    cost = tensor.sum((x-W)**2)
    cost.name = "cost"

    algorithm = GradientDescent(cost=cost, parameters=[W],
                                step_rule=Scale(1e-3))

    main_loop = MainLoop(
        model=None, data_stream=dataset.get_example_stream(),
        algorithm=algorithm,
        extensions=[
            FinishAfter(after_n_epochs=1),
            extension])

    return main_loop
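
For context on the comment above, a hypothetical sketch of the progressbar2 API difference it guards against (progressbar2 is the assumed backend here):

import progressbar

# progressbar2 >= 3.6.0: the total is given via `max_value`
bar = progressbar.ProgressBar(max_value=101)
# Per the comment above, if the old `maxval` keyword were still used, `max_value`
# would stay at its default of 100 and updating the bar to 101 would be
# reported as a value out of range.
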
Example #6
def get_data_stream(iterable):
    """Returns a 'fuel.Batch' datastream of
    [x~input~numbers, y~targets~roots], with each iteration returning a
    batch of 20 training examples
    """
    dataset = IterableDataset({"numbers": iterable})
    data_stream = Mapping(dataset.get_example_stream(), _data_sqrt, add_sources=("roots",))
    data_stream = Mapping(data_stream, _array_tuple)
    return Batch(data_stream, ConstantScheme(20))
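
The `_data_sqrt` and `_array_tuple` mapping helpers are not part of this snippet; a plausible sketch of what they do, given that Mapping hands each example over as a tuple of source values (hypothetical implementations):

import math

import numpy

def _data_sqrt(data):
    # `data` is a one-element tuple holding the current number; emit its
    # square root as the new 'roots' source.
    return (math.sqrt(data[0]),)

def _array_tuple(data):
    # Cast every source to a numpy array so later batching yields arrays.
    return tuple(numpy.asarray(d) for d in data)
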
Example #7
def test_merge():
    english = IterableDataset(['Hello world!'])
    french = IterableDataset(['Bonjour le monde!'])
    streams = (english.get_example_stream(),
               french.get_example_stream())
    merged_stream = Merge(streams, ('english', 'french'))
    assert merged_stream.sources == ('english', 'french')
    assert (next(merged_stream.get_epoch_iterator()) ==
            ('Hello world!', 'Bonjour le monde!'))
Example #8
def get_data_stream(iterable):
    """Returns a 'fuel.Batch' datastream of
    [x~input~numbers, y~targets~roots], with each iteration returning a
    batch of 20 training examples
    """
    numbers = numpy.asarray(iterable, dtype=floatX)
    dataset = IterableDataset(
        {'numbers': numbers, 'roots': numpy.sqrt(numbers)})
    return Batch(dataset.get_example_stream(), ConstantScheme(20))
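
A quick sketch of consuming the returned stream (the sizes follow from ConstantScheme(20); `as_dict` avoids relying on source ordering):

stream = get_data_stream(range(100))
batch = next(stream.get_epoch_iterator(as_dict=True))
# each iteration yields one batch of 20 numbers and their 20 square roots
assert len(batch['numbers']) == 20 and len(batch['roots']) == 20
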
Example #9
def test_training_data_monitoring():
    weights = numpy.array([-1, 1], dtype=theano.config.floatX)
    features = [numpy.array(f, dtype=theano.config.floatX)
                for f in [[1, 2], [3, 4], [5, 6]]]
    targets = [(weights * f).sum() for f in features]
    n_batches = 3
    dataset = IterableDataset(dict(features=features, targets=targets))

    x = tensor.vector('features')
    y = tensor.scalar('targets')
    W = shared_floatx([0, 0], name='W')
    V = shared_floatx(7, name='V')
    W_sum = named_copy(W.sum(), 'W_sum')
    cost = ((x * W).sum() - y) ** 2
    cost.name = 'cost'

    class TrueCostExtension(TrainingExtension):

        def before_batch(self, data):
            self.main_loop.log.current_row['true_cost'] = (
                ((W.get_value() * data["features"]).sum() -
                 data["targets"]) ** 2)

    main_loop = MainLoop(
        model=None, data_stream=dataset.get_example_stream(),
        algorithm=GradientDescent(cost=cost, params=[W],
                                  step_rule=Scale(0.001)),
        extensions=[
            FinishAfter(after_n_epochs=1),
            TrainingDataMonitoring([W_sum, cost, V], prefix="train1",
                                   after_batch=True),
            TrainingDataMonitoring([aggregation.mean(W_sum), cost],
                                   prefix="train2", after_epoch=True),
            TrueCostExtension()])

    main_loop.run()

    # Check monitoring of a shared variable
    assert_allclose(main_loop.log.current_row['train1_V'], 7.0)

    for i in range(n_batches):
        # The ground truth is written to the log before the batch is
        # processed, whereas the extension writes after the batch is
        # processed. This is why the iteration numbers differ here.
        assert_allclose(main_loop.log[i]['true_cost'],
                        main_loop.log[i + 1]['train1_cost'])
    assert_allclose(
        main_loop.log[n_batches]['train2_cost'],
        sum([main_loop.log[i]['true_cost']
             for i in range(n_batches)]) / n_batches)
    assert_allclose(
        main_loop.log[n_batches]['train2_W_sum'],
        sum([main_loop.log[i]['train1_W_sum']
             for i in range(1, n_batches + 1)]) / n_batches)
Example #10
def do_test(with_serialization):
    data_stream = IterableDataset(range(10)).get_example_stream()
    main_loop = MainLoop(MockAlgorithm(),
                         data_stream,
                         extensions=[
                             WriteBatchExtension(),
                             FinishAfter(after_n_batches=14)
                         ])
    main_loop.run()
    assert main_loop.log.status['iterations_done'] == 14

    if with_serialization:
        main_loop = cPickle.loads(cPickle.dumps(main_loop))

    finish_after = unpack(
        [ext for ext in main_loop.extensions
         if isinstance(ext, FinishAfter)],
        singleton=True)
    finish_after.add_condition(
        ["after_batch"],
        predicate=lambda log: log.status['iterations_done'] == 27)
    main_loop.run()
    assert main_loop.log.status['iterations_done'] == 27
    assert main_loop.log.status['epochs_done'] == 2
    for i in range(27):
        assert main_loop.log[i + 1]['batch'] == {"data": i % 10}
Example #11
def test_filter_examples(self):
    data = [1, 2, 3]
    data_filtered = [1, 3]
    stream = DataStream(IterableDataset(data))
    wrapper = Filter(stream, lambda d: d[0] % 2 == 1)
    assert_equal(list(wrapper.get_epoch_iterator()),
                 list(zip(data_filtered)))
Example #12
def setup_mainloop(extension, iteration_scheme=None):
    """Set up a simple main loop for progress bar tests.

    Create a MainLoop, register the given extension, supply it with a
    DataStream and a minimal model/cost to optimize.

    """
    # Since progressbar2 3.6.0, the `maxval` kwarg has been replaced by
    # `max_value`, which has a default value of 100. If we're still using
    # `maxval` by accident, this test should fail complaining that
    # the progress bar has received a value out of range.
    features = [
        numpy.array(f, dtype=theano.config.floatX) for f in [[1, 2]] * 101
    ]
    dataset = IterableDataset(dict(features=features))
    data_stream = DataStream(dataset, iteration_scheme=iteration_scheme)

    W = shared_floatx([0, 0], name='W')
    x = tensor.vector('features')
    cost = tensor.sum((x - W)**2)
    cost.name = "cost"

    algorithm = GradientDescent(cost=cost,
                                parameters=[W],
                                step_rule=Scale(1e-3))

    main_loop = MainLoop(model=None,
                         data_stream=data_stream,
                         algorithm=algorithm,
                         extensions=[FinishAfter(after_n_epochs=1), extension])

    return main_loop
Example #13
def test_add_sources(self):
    stream = DataStream(IterableDataset(self.data))
    transformer = Mapping(stream, lambda d: ([2 * i for i in d[0]],),
                          add_sources=('doubled',))
    assert_equal(transformer.sources, ('data', 'doubled'))
    assert_equal(list(transformer.get_epoch_iterator()),
                 list(zip(self.data, [[2, 4, 6], [4, 6, 2], [6, 4, 2]])))
Example #14
def test_value_error_on_request(self):
    transformer = Padding(Batch(
        DataStream(
            IterableDataset(
                dict(features=[[1], [2, 3]], targets=[[4, 5, 6], [7]]))),
        ConstantScheme(2)))
    assert_raises(ValueError, transformer.get_data, [0, 1])
Example #15
def test_two_sources(self):
    transformer = Padding(Batch(
        DataStream(
            IterableDataset(
                dict(features=[[1], [2, 3]], targets=[[4, 5, 6], [7]]))),
        ConstantScheme(2)))
    assert len(next(transformer.get_epoch_iterator())) == 4
Example #16
def setUp(self):
    dataset = IterableDataset(
        OrderedDict([('features', [1, 2, 3]), ('targets', [0, 1, 0])]),
        axis_labels={'features': ('batch'), 'targets': ('batch')})
    self.stream = DataStream(dataset)
    self.wrapper = ScaleAndShift(
        self.stream, 2, -1, which_sources=('targets',))
Example #17
def test_adds_batch_to_axis_labels(self):
    stream = DataStream(
        IterableDataset(
            {'features': [1, 2, 3, 4, 5]},
            axis_labels={'features': ('index',)}))
    transformer = Batch(stream, ConstantScheme(2), strictness=0)
    assert_equal(transformer.axis_labels, {'features': ('batch', 'index')})
Example #18
def test_scale_and_shift():
    stream = DataStream(
        IterableDataset({
            'features': [1, 2, 3],
            'targets': [0, 1, 0]
        }))
    wrapper = ScaleAndShift(stream, 2, -1, which_sources=('targets', ))
    assert list(wrapper.get_epoch_iterator()) == [(1, -1), (2, 1), (3, -1)]
Example #19
def test_ngram_stream_raises_error_on_request():
    sentences = [
        list(numpy.random.randint(10, size=sentence_length))
        for sentence_length in [3, 5, 7]
    ]
    stream = DataStream(IterableDataset(sentences))
    ngrams = NGrams(4, stream)
    assert_raises(ValueError, ngrams.get_data, [0, 1])
Example #20
def test_num_examples():
    assert_raises(ValueError, IterableDataset, {
        'features': range(10),
        'targets': range(7)
    })
    dataset = IterableDataset({'features': range(7), 'targets': range(7)})
    assert dataset.num_examples == 7
    dataset = IterableDataset(repeat(1))
    assert numpy.isnan(dataset.num_examples)
    x = numpy.random.rand(5, 3)
    y = numpy.random.rand(5, 4)
    dataset = IndexableDataset({'features': x, 'targets': y})
    assert dataset.num_examples == 5
    assert_raises(ValueError, IndexableDataset, {
        'features': x,
        'targets': y[:4]
    })
Example #21
def test_force_floatx():
    x = [numpy.array(d, dtype="float64") for d in [[1, 2], [3, 4], [5, 6]]]
    y = [numpy.array(d, dtype="int64") for d in [1, 2, 3]]
    dataset = IterableDataset(OrderedDict([("x", x), ("y", y)]))
    wrapper = ForceFloatX(DataStream(dataset))
    data = next(wrapper.get_epoch_iterator())
    assert str(data[0].dtype) == config.floatX
    assert str(data[1].dtype) == "int64"
Example #22
def test_mapping_dict(self):
    def mapping(d):
        return {'data': [2 * i for i in d['data']]}

    stream = DataStream(IterableDataset(self.data))
    transformer = Mapping(stream, mapping, mapping_accepts=dict)
    assert_equal(list(transformer.get_epoch_iterator()),
                 list(zip([[2, 4, 6], [4, 6, 2], [6, 4, 2]])))
Example #23
def test_mapping_sort_multisource(self):
    data = OrderedDict([('x', self.data_x), ('y', self.data_y)])
    data_sorted = [([1, 2, 3], [6, 5, 4]), ([1, 2, 3], [4, 6, 5]),
                   ([1, 2, 3], [4, 5, 6])]
    stream = DataStream(IterableDataset(data))
    transformer = Mapping(stream,
                          mapping=SortMapping(operator.itemgetter(0)))
    assert_equal(list(transformer.get_epoch_iterator()), data_sorted)
Example #24
def test_ngram_stream():
    sentences = [
        list(numpy.random.randint(10, size=sentence_length))
        for sentence_length in [3, 5, 7]
    ]
    stream = IterableDataset(sentences).get_example_stream()
    ngrams = NGrams(4, stream)
    assert len(list(ngrams.get_epoch_iterator())) == 4
Example #25
def test_iterable_dataset():
    from fuel.datasets import IterableDataset

    seed = 1234
    rng = numpy.random.RandomState(seed)
    features = rng.randint(256, size=(8, 2, 2))
    targets = rng.randint(4, size=(8, 1))

    dataset = IterableDataset(iterables=OrderedDict([('features', features), ('targets', targets)]),
                              axis_labels=OrderedDict([('features', ('batch', 'height', 'width')),
                                                       ('targets', ('batch', 'index'))]))

    print('Provided sources are {}.'.format(dataset.provides_sources))
    print('Sources are {}.'.format(dataset.sources))
    print('Axis labels are {}.'.format(dataset.axis_labels))
    print('Dataset contains {} examples.'.format(dataset.num_examples))

    state = dataset.open()
    while True:
        try:
            print(dataset.get_data(state=state))
        except StopIteration:
            print('Iteration over')
            break

    state = dataset.reset(state=state)
    print(dataset.get_data(state=state))

    dataset.close(state=state)
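
Roughly what the prints would show (the array contents depend on the RandomState draws, so only the structure is indicated):

Provided sources are ('features', 'targets').
Sources are ('features', 'targets').
Axis labels are OrderedDict([('features', ('batch', 'height', 'width')), ('targets', ('batch', 'index'))]).
Dataset contains 8 examples.
(array([[..., ...], [..., ...]]), array([...]))   # one line per example, 8 in total
Iteration over
(array([[..., ...], [..., ...]]), array([...]))   # the first example again, after reset()
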
Example #26
def test_mapping_sort_multisource():
    data = OrderedDict()
    data['x'] = [[1, 2, 3], [2, 3, 1], [3, 2, 1]]
    data['y'] = [[6, 5, 4], [6, 5, 4], [6, 5, 4]]
    data_sorted = [([1, 2, 3], [6, 5, 4]), ([1, 2, 3], [4, 6, 5]),
                   ([1, 2, 3], [4, 5, 6])]
    stream = DataStream(IterableDataset(data))
    wrapper = Mapping(stream, mapping=SortMapping(operator.itemgetter(0)))
    assert list(wrapper.get_epoch_iterator()) == data_sorted
Example #27
def setUp(self):
    dataset = IterableDataset(
        OrderedDict([
            ('features', numpy.array([1, 2, 3]).astype('float64')),
            ('targets', [0, 1, 0])]),
        axis_labels={'features': ('batch'), 'targets': ('batch')})
    self.stream = DataStream(dataset)
    self.wrapper = Cast(
        self.stream, 'float32', which_sources=('features',))
Example #28
def test_unpack_transformer():
    data = range(10)
    stream = DataStream(IterableDataset(data))
    stream = Batch(stream, iteration_scheme=ConstantScheme(2))
    stream = Unpack(stream)
    epoch = stream.get_epoch_iterator()
    for i, v in enumerate(epoch):
        assert numpy.shape(v)[0] == 1
        assert v[0] == i
Example #29
def test_mask_dtype(self):
    transformer = Padding(Batch(
        DataStream(
            IterableDataset(
                dict(features=[[1], [2, 3]], targets=[[4, 5, 6], [7]]))),
        ConstantScheme(2)),
        mask_dtype='uint8')
    assert_equal(
        str(next(transformer.get_epoch_iterator())[1].dtype), 'uint8')
Example #30
def test_mask_sources(self):
    transformer = Padding(Batch(
        DataStream(
            IterableDataset(
                OrderedDict([('features', [[1], [2, 3]]),
                             ('targets', [[4, 5, 6], [7]])]))),
        ConstantScheme(2)),
        mask_sources=('features',))
    assert_equal(len(next(transformer.get_epoch_iterator())), 3)
Example #31
def test_ngram_stream_error_on_multiple_sources():
    # Check that NGrams accepts only data streams with one source
    sentences = [
        list(numpy.random.randint(10, size=sentence_length))
        for sentence_length in [3, 5, 7]
    ]
    stream = DataStream(IterableDataset(sentences))
    stream.sources = ('1', '2')
    assert_raises(ValueError, NGrams, 4, stream)
Example #32
def test_2d_sequences(self):
    stream = Batch(
        DataStream(
            IterableDataset([numpy.ones((3, 4)), 2 * numpy.ones((2, 4))])),
        ConstantScheme(2))
    it = Padding(stream).get_epoch_iterator()
    data, mask = next(it)
    assert data.shape == (2, 3, 4)
    assert (data[0, :, :] == 1).all()
    assert (data[1, :2, :] == 2).all()
    assert (mask == numpy.array([[1, 1, 1], [1, 1, 0]])).all()
Example #33
def test_cast():
    stream = DataStream(
        IterableDataset(
            OrderedDict([
                ('features', numpy.array([1, 2, 3]).astype('float64')),
                ('targets', [0, 1, 0])])))
    wrapper = Cast(stream, 'float32', which_sources=('features',))
    assert_equal(
        list(wrapper.get_epoch_iterator()),
        [(numpy.array(1), 0), (numpy.array(2), 1), (numpy.array(3), 0)])
    assert all(f.dtype == 'float32' for f, t in wrapper.get_epoch_iterator())
Example #34
def test_mapping():
    data = [1, 2, 3]
    data_doubled = [2, 4, 6]
    stream = DataStream(IterableDataset(data))
    wrapper1 = Mapping(stream, lambda d: (2 * d[0], ))
    assert list(wrapper1.get_epoch_iterator()) == list(zip(data_doubled))
    wrapper2 = Mapping(stream,
                       lambda d: (2 * d[0], ),
                       add_sources=("doubled", ))
    assert wrapper2.sources == ("data", "doubled")
    assert list(wrapper2.get_epoch_iterator()) == list(zip(data, data_doubled))
Example #35
def test_1d_sequences(self):
    stream = Batch(
        DataStream(IterableDataset([[1], [2, 3], [], [4, 5, 6], [7]])),
        ConstantScheme(2))
    transformer = Padding(stream)
    assert_equal(transformer.sources, ("data", "data_mask"))
    assert_equal(list(transformer.get_epoch_iterator()),
                 [(numpy.array([[1, 0], [2, 3]]),
                   numpy.array([[1, 0], [1, 1]])),
                  (numpy.array([[0, 0, 0], [4, 5, 6]]),
                   numpy.array([[0, 0, 0], [1, 1, 1]])),
                  (numpy.array([[7]]), numpy.array([[1]]))])
Example #36
def test_training_data_monitoring_updates_algorithm():
    features = [numpy.array(f, dtype=theano.config.floatX)
                for f in [[1, 2], [3, 5], [5, 8]]]
    targets = numpy.array([f.sum() for f in features])
    dataset = IterableDataset(dict(features=features, targets=targets))

    x = tensor.vector('features')
    y = tensor.scalar('targets')
    m = x.mean().copy(name='features_mean')
    t = y.sum().copy(name='targets_sum')

    main_loop = MainLoop(
        model=None, data_stream=dataset.get_example_stream(),
        algorithm=UpdatesAlgorithm(),
        extensions=[TrainingDataMonitoring([m, t], prefix="train1",
                                           after_batch=True)],
    )
    main_loop.extensions[0].main_loop = main_loop
    assert len(main_loop.algorithm.updates) == 0
    main_loop.extensions[0].do('before_training')
    assert len(main_loop.algorithm.updates) > 0
Example #37
def test_batch_data_stream():
    stream = DataStream(IterableDataset([1, 2, 3, 4, 5]))
    batches = list(Batch(stream, ConstantScheme(2))
                   .get_epoch_iterator())
    expected = [(numpy.array([1, 2]),),
                (numpy.array([3, 4]),),
                (numpy.array([5]),)]
    assert len(batches) == len(expected)
    for b, e in zip(batches, expected):
        assert (b[0] == e[0]).all()

    # Check the `strict` flag
    def try_strict(strictness):
        return list(
            Batch(stream, ConstantScheme(2), strictness=strictness)
            .get_epoch_iterator())
    assert_raises(ValueError, try_strict, 2)
    assert len(try_strict(1)) == 2
    stream2 = DataStream(IterableDataset([1, 2, 3, 4, 5, 6]))
    assert len(list(Batch(stream2, ConstantScheme(2), strictness=2)
                    .get_epoch_iterator())) == 3
Example #38
def test_mapping_sort():
    data = [[1, 2, 3], [2, 3, 1], [3, 2, 1]]
    data_sorted = [[1, 2, 3]] * 3
    data_sorted_rev = [[3, 2, 1]] * 3
    stream = DataStream(IterableDataset(data))
    wrapper1 = Mapping(stream, SortMapping(operator.itemgetter(0)))
    assert list(wrapper1.get_epoch_iterator()) == list(zip(data_sorted))
    wrapper2 = Mapping(stream, SortMapping(lambda x: -x[0]))
    assert list(wrapper2.get_epoch_iterator()) == list(zip(data_sorted_rev))
    wrapper3 = Mapping(stream, SortMapping(operator.itemgetter(0),
                                           reverse=True))
    assert list(wrapper3.get_epoch_iterator()) == list(zip(data_sorted_rev))
Example #39
def test_dataset():
    data = [1, 2, 3]
    # The default stream requests an example at a time
    stream = DataStream(IterableDataset(data))
    epoch = stream.get_epoch_iterator()
    assert list(epoch) == list(zip(data))

    # Check if iterating over multiple epochs works
    for i, epoch in zip(range(2), stream.iterate_epochs()):
        assert list(epoch) == list(zip(data))

    # Check whether the returning as a dictionary of sources works
    assert next(stream.get_epoch_iterator(as_dict=True)) == {"data": 1}
Example #40
def test_epoch_finishes_correctly(self):
    cached_stream = Cache(self.stream, ConstantScheme(7))
    data = list(cached_stream.get_epoch_iterator())
    assert_equal(len(data[-1][0]), 100 % 7)
    assert not cached_stream.cache[0]

    stream = Batch(DataStream(IterableDataset(range(3000))),
                   ConstantScheme(3200))

    cached_stream = Cache(stream, ConstantScheme(64))
    data = list(cached_stream.get_epoch_iterator())
    assert_equal(len(data[-1][0]), 3000 % 64)
    assert not cached_stream.cache[0]
Example #41
def test_merge():
    english = IterableDataset(['Hello world!'])
    french = IterableDataset(['Bonjour le monde!'])
    streams = (english.get_example_stream(), french.get_example_stream())
    merged_stream = Merge(streams, ('english', 'french'))
    assert merged_stream.sources == ('english', 'french')
    assert (next(merged_stream.get_epoch_iterator()) == ('Hello world!',
                                                         'Bonjour le monde!'))
Example #42
File: sqrt.py, Project: basaundi/blocks
def get_data_stream(iterable):
    dataset = IterableDataset({'numbers': iterable})
    data_stream = Mapping(dataset.get_example_stream(),
                          _data_sqrt, add_sources=('roots',))
    data_stream = Mapping(data_stream, _array_tuple)
    return Batch(data_stream, ConstantScheme(20))
Example #43
File: test_datasets.py, Project: Afrik/fuel
def test_filter_sources(self):
    dataset = IterableDataset(
        OrderedDict([('1', [1, 2]), ('2', [3, 4])]), sources=('1',))
    assert_equal(dataset.filter_sources(([1, 2], [3, 4])), ([1, 2],))