def test_datastream_merge():
    """Smoke-test drawing from a merged datastream's sampler and data loader.

    NOTE(review): another function named ``test_datastream_merge`` is defined
    later in this file; the later definition rebinds the name, so pytest will
    presumably only collect that one.
    """

    datastream = Datastream.merge([
        Datastream(Dataset.from_subscriptable(list('abc'))),
        Datastream(Dataset.from_subscriptable(list('def'))),
    ])

    # Draw a couple of indices directly from the merged sampler.
    sampler_iterator = iter(datastream.sampler)
    for _ in range(2):
        next(sampler_iterator)

    # Without ``n_batches_per_epoch`` an epoch covers the (small) dataset once,
    # i.e. a single batch of at most 8 examples here, so ten ``next`` calls
    # would raise StopIteration. Ask for ten batches per epoch explicitly.
    batch_iterator = iter(
        datastream.data_loader(batch_size=8, n_batches_per_epoch=10))
    for _ in range(10):
        next(batch_iterator)
# Example #2
def test_last_batch():
    """Check batch lengths for a short final batch and for fixed-size epochs."""
    from datastream.samplers import SequentialSampler

    def batch_lengths(stream, **loader_kwargs):
        # Length of every batch produced by one pass over the loader.
        return [len(batch) for batch in stream.data_loader(**loader_kwargs)]

    default_stream = Datastream(Dataset.from_subscriptable(list('abc')))
    assert batch_lengths(default_stream, batch_size=4) == [3]
    assert batch_lengths(
        default_stream, batch_size=4, n_batches_per_epoch=2) == [4, 4]

    sequential_stream = Datastream(
        Dataset.from_subscriptable(list('abc')),
        SequentialSampler(3),
    )
    assert batch_lengths(sequential_stream, batch_size=2) == [2, 1]
# Example #3
def test_take():
    """``take(n)`` limits the examples drawn; ``take(0)`` raises ValueError."""

    import pytest

    def letter_stream(text):
        # One datastream over the characters of ``text``.
        return Datastream(Dataset.from_subscriptable(list(text)))

    limited = letter_stream('abc').take(2)
    assert len(list(limited.data_loader(batch_size=1))) == 2

    with pytest.raises(ValueError):
        letter_stream('abc').take(0)

    merged = Datastream.merge([letter_stream('abc'), letter_stream('d')])
    assert len(list(merged.take(2).data_loader(batch_size=1))) == 2
# Example #4
def test_multi_sample():
    """Zeroing example weights must keep those examples out of later draws.

    Builds a pipeline that squares each number, multi-samples, subsamples,
    attaches the example index and takes the square root again. It then zeroes
    the weights of the first two drawn examples and checks that none of their
    indices appear in a second pass.
    """

    data = [1, 2, 4]
    n_multi_sample = 2

    datastream = (
        Datastream(Dataset.from_subscriptable(data))
        .map(lambda number: number**2)
        .multi_sample(n_multi_sample)
        .sample_proportion(0.5)
        .zip_index()
        .starmap(lambda number, index: (number**0.5, index))
    )

    output = [(number, index)
              for number, index in datastream.data_loader(batch_size=1)]
    assert len(output) == len(data) * n_multi_sample

    # Round-trip the state; this only checks that saving/loading succeeds.
    state = datastream.state_dict()
    datastream.load_state_dict(state)

    # Zero the weight of each of the first two drawn examples.
    # NOTE(review): weight-first argument order matches the other
    # ``update_example_weight_(0.0, index)`` calls in this file; confirm
    # against the library signature.
    for _, index in output[:2]:
        datastream.update_example_weight_(0, index)

    output2 = [(number, index)
               for number, index in datastream.data_loader(batch_size=1)]
    assert len(output2) == len(data) * n_multi_sample

    zero_indices = set(index for _, index in output[:2])
    for _, index in output2:
        assert index not in zero_indices
# Example #5
def test_datastream_merge():
    """A merged datastream supports sampler iteration, a ten-batch epoch and a
    full single-example pass whose length matches ``len(datastream)``."""

    parts = [
        Datastream(Dataset.from_subscriptable(list(letters)))
        for letters in ('abc', 'def')
    ]
    datastream = Datastream.merge(parts)

    sampler_iterator = iter(datastream.sampler)
    next(sampler_iterator)
    next(sampler_iterator)

    batches = iter(
        datastream.data_loader(batch_size=8, n_batches_per_epoch=10))
    for _ in range(10):
        next(batches)

    single_example_batches = list(datastream.data_loader(batch_size=1))
    assert len(single_example_batches) == len(datastream)
# Example #6
def test_concat_merge():
    """Merge a concatenated dataset with a three-example subset of itself."""

    def first_three(df):
        # Boolean mask keeping only positions 0, 1 and 2.
        return [index < 3 for index in range(len(df))]

    dataset = Dataset.concat([
        Dataset.from_subscriptable([1, 2]),
        Dataset.from_subscriptable([1, 3, 5]),
    ])

    full_stream = Datastream(dataset)
    subset_stream = Datastream(dataset.subset(first_three))
    datastream = Datastream.merge([full_stream, subset_stream])

    assert len(dataset.subset(first_three)) == 3

    assert len(list(datastream)) == 6


def test_datastream_zip():
    """Zip three sequentially sampled datastreams and inspect the first batch.

    Shorter streams start over within the batch (1, 2, 1 and 6, 7, 6).
    """

    datasets = [
        Dataset.from_subscriptable(values)
        for values in ([1, 2], [3, 4, 5], [6, 7])
    ]

    zipped_datastream = Datastream.zip([
        Datastream(dataset, sampler=torch.utils.data.SequentialSampler(dataset))
        for dataset in datasets
    ])

    first_batch = next(iter(zipped_datastream.data_loader(batch_size=3)))

    assert len(first_batch) == 3 and len(first_batch[0]) == 3

    expected_per_stream = (
        (1, 2, 1),
        (3, 4, 5),
        (6, 7, 6),
    )
    for stream_values, expected in zip(first_batch, expected_per_stream):
        for position, value in enumerate(expected):
            assert stream_values[position] == value
# Example #8
def test_sequential_sampler():
    """A sequential sampler honours ``take`` and restarts from the beginning
    when a batch asks for more examples than the dataset holds."""

    from datastream.samplers import SequentialSampler

    dataset = Dataset.from_subscriptable(list('abc'))

    limited = Datastream(dataset, SequentialSampler(len(dataset))).take(2)
    n_batches = len(list(limited.data_loader(batch_size=1)))
    assert n_batches == 2

    looping = Datastream(dataset, SequentialSampler(len(dataset)))
    loader = looping.data_loader(batch_size=6, n_batches_per_epoch=10)
    first_batch = next(iter(loader))
    assert first_batch == ['a', 'b', 'c', 'a', 'b', 'c']
# Example #9
def test_merge_datastream_weights():
    """Smoke-test per-example and bulk weight updates on a merged datastream."""

    datasets = [
        Dataset.from_subscriptable(values)
        for values in ([1, 2], [3, 4, 5], [6, 7])
    ]

    def as_record(integer, index):
        # Turn an (integer, index) pair into a dict example.
        return dict(
            integer=integer,
            index=index,
        )

    def blend_towards_one(weights):
        return weights * 0.9 + 1 * 0.1

    merged = Datastream.merge([Datastream(dataset) for dataset in datasets])
    datastream = merged.zip_index().starmap(as_record).sample_proportion(0.5)

    for removed_index in [0, 3]:
        datastream.update_example_weight_(0.0, removed_index)

    # Draw a few batches; the drawn examples themselves are not inspected.
    list(datastream.data_loader(batch_size=4, n_batches_per_epoch=4))

    datastream.update_weights_(blend_towards_one)
# Example #10
def test_datastream_simple_weights():
    """Examples whose weight has been set to zero must never be drawn."""

    dataset = Dataset.from_subscriptable([1, 2, 3, 4])
    datastream = (
        Datastream(dataset).zip_index().starmap(lambda integer, index: dict(
            integer=integer,
            index=index,
        )).sample_proportion(0.5))

    removed_indices = [0, 3]
    for index in removed_indices:
        # Zero exactly this example. The loop variable must be passed, not the
        # whole ``removed_indices`` list on every iteration.
        datastream.update_example_weight_(0.0, index)

    samples = list(datastream.data_loader(batch_size=1))

    assert len(samples) == 2

    for sample in samples:
        if sample['index'] in removed_indices:
            raise AssertionError(
                'Samples with 0 weight were drawn from the dataset')
# Example #11
def test_combine_concat_merge():
    """Merge a concat of zip/combine datasets with a one-example zip dataset."""
    zipped_pair = Dataset.zip([
        Dataset.from_subscriptable([1]),
        Dataset.from_subscriptable([2]),
    ])
    combined = Dataset.combine([
        Dataset.from_subscriptable([3, 3]),
        Dataset.from_subscriptable([4, 4, 4]),
    ])
    dataset = Dataset.concat([zipped_pair, combined])

    main_stream = Datastream(dataset)
    small_stream = Datastream(
        Dataset.zip([
            Dataset.from_subscriptable([5]),
            Dataset.from_subscriptable([6]),
        ]))
    datastream = Datastream.merge([main_stream, small_stream])

    assert len(list(datastream)) == 2
# Example #12
 def RandomDatastream():
     """Return a datastream over ``range(k)`` for a random ``k`` in 1..9."""
     n_examples = np.random.randint(1, 10)
     examples = list(range(n_examples))
     return Datastream(Dataset.from_subscriptable(examples))
# Example #13
def test_empty():
    """Building a datastream from an empty dataset raises ValueError."""

    import pytest

    with pytest.raises(ValueError):
        Datastream(Dataset.from_subscriptable([]))
# Example #14
def test_iter():
    """Iterating a datastream directly yields one item per dataset example."""

    letters = list('abc')
    datastream = Datastream(Dataset.from_subscriptable(letters))
    items = list(datastream)
    assert len(items) == 3
# Example #15
def test_infinite():
    """With ``n_batches_per_epoch`` the loader yields ten batches from a
    three-example dataset."""

    datastream = Datastream(Dataset.from_subscriptable(list('abc')))
    loader = datastream.data_loader(batch_size=8, n_batches_per_epoch=10)
    batches = iter(loader)
    for _ in range(10):
        next(batches)