Ejemplo n.º 1
0
def test_last_batch():
    '''
    Check the sizes of the batches produced at the end of an epoch.
    '''
    from datastream.samplers import SequentialSampler

    def batch_lengths(loader):
        return [len(batch) for batch in loader]

    default_stream = Datastream(Dataset.from_subscriptable(list('abc')))
    # Without an explicit epoch length, three examples fit in one batch of 4.
    assert batch_lengths(default_stream.data_loader(batch_size=4)) == [3]
    # With an explicit epoch length, every batch is filled to batch_size.
    assert batch_lengths(
        default_stream.data_loader(batch_size=4, n_batches_per_epoch=2)
    ) == [4, 4]

    ordered_stream = Datastream(
        Dataset.from_subscriptable(list('abc')),
        SequentialSampler(3),
    )
    # Three examples in batches of two leave a trailing batch of one.
    assert batch_lengths(ordered_stream.data_loader(batch_size=2)) == [2, 1]
Ejemplo n.º 2
0
def test_datastream_merge():
    '''
    Smoke test: a merged datastream can be sampled from, and its data
    loader can produce more batches than a single pass over the examples.
    '''
    datastream = Datastream.merge([
        Datastream(Dataset.from_subscriptable(list('abc'))),
        Datastream(Dataset.from_subscriptable(list('def'))),
    ])

    it = iter(datastream.sampler)
    for _ in range(2):
        next(it)

    # Six examples fit into a single batch of eight, so the epoch length has
    # to be set explicitly for ten consecutive batches to be available;
    # otherwise the second ``next`` raises StopIteration.
    it = iter(datastream.data_loader(batch_size=8, n_batches_per_epoch=10))
    for _ in range(10):
        next(it)
Ejemplo n.º 3
0
def test_take():
    '''
    ``take(n)`` limits a datastream (plain or merged) to ``n`` examples and
    rejects ``n == 0``.
    '''
    import pytest

    def n_single_example_batches(stream):
        return len(list(stream.data_loader(batch_size=1)))

    limited = Datastream(Dataset.from_subscriptable(list('abc'))).take(2)
    assert n_single_example_batches(limited) == 2

    with pytest.raises(ValueError):
        Datastream(Dataset.from_subscriptable(list('abc'))).take(0)

    merged = Datastream.merge([
        Datastream(Dataset.from_subscriptable(list('abc'))),
        Datastream(Dataset.from_subscriptable(list('d'))),
    ])
    assert n_single_example_batches(merged.take(2)) == 2
Ejemplo n.º 4
0
def test_multi_sample():
    '''
    Check ``multi_sample`` combined with ``sample_proportion``,
    ``zip_index``, state round-tripping and example weights.

    An epoch is expected to contain ``len(data) * n_multi_sample`` examples,
    and indices whose weight was set to zero should not be drawn afterwards.
    '''

    data = [1, 2, 4]
    n_multi_sample = 2

    # Square each value, multi-sample, subsample with sample_proportion(0.5),
    # pair each example with its index, then take the square root again.
    datastream = (
        Datastream(
            Dataset.from_subscriptable(data)).map(lambda number: number**2).
        multi_sample(n_multi_sample).sample_proportion(0.5).zip_index(
        ).starmap(lambda number, index: (number**0.5, index)))

    output = [(number, index)
              for number, index in datastream.data_loader(batch_size=1)]
    assert len(output) == len(data) * n_multi_sample
    print(output)

    # Round-trip the datastream state through state_dict/load_state_dict.
    state = datastream.state_dict()
    datastream.load_state_dict(state)

    # NOTE(review): ``zip(output, range(2))`` binds ``index`` to a whole
    # ``(number, index)`` pair from ``output`` and ``number`` to 0 or 1, so
    # a pair rather than an example index is passed below. The intent seems
    # to be zeroing the weights of the first two sampled indices. Also
    # confirm the argument order: other tests call
    # ``update_example_weight_(weight, index)``.
    for index, number in zip(output, range(2)):
        datastream.update_example_weight_(index, 0)

    output2 = [(number, index)
               for number, index in datastream.data_loader(batch_size=1)]
    assert len(output2) == len(data) * n_multi_sample

    # NOTE(review): if the collated ``index`` values are tensors, set
    # membership is by identity and this check may pass vacuously - confirm.
    zero_indices = set([index for _, index in output[:2]])
    for number, index in output2:
        assert index not in zero_indices
Ejemplo n.º 5
0
    def merge(
        datastreams_and_ns: Tuple[Union[Datastream[T], Tuple[Datastream[T],
                                                             int]], ...]
    ) -> Datastream[T]:
        '''
        Merge multiple datastreams by interleaving them. Optionally you can
        define different lengths per ``Datastream``.

        Each element of ``datastreams_and_ns`` is either a ``Datastream`` or a
        ``(datastream, n)`` tuple; a bare ``Datastream`` is treated as
        ``(datastream, 1)``.

        .. highlight:: python
        .. code-block:: python

            Datastream.merge([
                (datastream1, 2),
                (datastream2, 1),
                (datastream3, 1),
            ])
        '''
        # Normalize into a separate list instead of rebinding the parameter
        # to a different type; isinstance also accepts tuple subclasses
        # such as namedtuples.
        pairs = [
            item if isinstance(item, tuple) else (item, 1)
            for item in datastreams_and_ns
        ]

        return Datastream(
            Dataset.concat([datastream.dataset for datastream, _ in pairs]),
            # Transpose the (sampler, dataset, n) triples into the three
            # positional arguments of MergeSampler.
            MergeSampler(*zip(*[
                (datastream.sampler, datastream.dataset, n)
                for datastream, n in pairs
            ])),
        )
Ejemplo n.º 6
0
def test_datastream_merge():
    '''
    A merged datastream can be sampled from and loaded for several batches.
    A batch-size-one loader yields ``len(datastream)`` batches.
    '''
    datastream = Datastream.merge([
        Datastream(Dataset.from_subscriptable(list(letters)))
        for letters in ('abc', 'def')
    ])

    sampler_iterator = iter(datastream.sampler)
    next(sampler_iterator)
    next(sampler_iterator)

    batches = iter(
        datastream.data_loader(batch_size=8, n_batches_per_epoch=10)
    )
    for _ in range(10):
        next(batches)

    n_batches = len(list(datastream.data_loader(batch_size=1)))
    assert n_batches == len(datastream)
Ejemplo n.º 7
0
    def merge(
        datastreams_and_ns: Tuple[Union[Datastream[T], Tuple[Datastream[T],
                                                             int]], ...]
    ) -> Datastream[T]:
        '''
        Creates a merged datastream where samples are drawn one at a time from
        each underlying datastream (also known as "interleave").

        Optionally you can define the number of drawn samples per
        ``Datastream`` by passing ``(datastream, n)`` tuples; a bare
        ``Datastream`` is treated as ``(datastream, 1)``.

        >>> datastream1 = Datastream(Dataset.from_subscriptable([1, 1]))
        >>> datastream2 = Datastream(Dataset.from_subscriptable([2, 2]))
        >>> datastream3 = Datastream(Dataset.from_subscriptable([3, 3, 3, 3]))
        >>> merged_datastream = Datastream.merge([
        ...     (datastream1, 1),
        ...     (datastream2, 1),
        ...     (datastream3, 2),
        ... ])
        >>> list(merged_datastream)
        [1, 2, 3, 3, 1, 2, 3, 3]
        '''
        # Normalize into a separate list instead of rebinding the parameter
        # to a different type; isinstance also accepts tuple subclasses
        # such as namedtuples.
        pairs = [
            item if isinstance(item, tuple) else (item, 1)
            for item in datastreams_and_ns
        ]

        return Datastream(
            Dataset.concat([datastream.dataset for datastream, _ in pairs]),
            # Transpose the (sampler, dataset, n) triples into the three
            # positional arguments of MergeSampler.
            MergeSampler(*zip(*[
                (datastream.sampler, datastream.dataset, n)
                for datastream, n in pairs
            ])),
        )
Ejemplo n.º 8
0
def MnistDataset(dataframe):
    '''
    Build a dataset of ``problem.Example`` items from ``dataframe`` rows.
    Rows are read via their ``image_path`` and ``class_name`` entries.

    Images are opened relative to the ``prepare`` directory and resized to
    32x32.
    '''
    def to_path_and_class(row):
        return Path(row["image_path"]), row["class_name"]

    def to_example(image_path, class_name):
        resized = Image.open("prepare" / image_path).resize((32, 32))
        return problem.Example(image=resized, class_name=class_name)

    rows = Dataset.from_dataframe(dataframe)
    return rows.map(to_path_and_class).starmap(to_example)
Ejemplo n.º 9
0
 def __init__(self, samplers, datasets):
     '''
     Store ``samplers`` and ``datasets`` together with the values derived
     from them: the length (the longest sampler), the combine mapping and
     the zipped sampler iterator.
     '''
     longest = max(len(sampler) for sampler in samplers)
     from_mapping = Dataset.create_from_combine_mapping(datasets)
     zipped = ZipSampler.zip_samplers(samplers, datasets)
     BaseModel.__init__(
         self,
         samplers=samplers,
         datasets=datasets,
         length=longest,
         from_mapping=from_mapping,
         zipped_samplers=zipped,
     )
Ejemplo n.º 10
0
def test_concat_merge():
    '''
    Merge a concatenated dataset with a three-example subset of itself and
    check the subset size and the number of iterated examples.
    '''
    def first_three(df):
        return [position < 3 for position in range(len(df))]

    dataset = Dataset.concat([
        Dataset.from_subscriptable([1, 2]),
        Dataset.from_subscriptable([1, 3, 5]),
    ])

    datastream = Datastream.merge([
        Datastream(dataset),
        Datastream(dataset.subset(first_three)),
    ])

    assert len(dataset.subset(first_three)) == 3

    assert len(list(datastream)) == 6
Ejemplo n.º 11
0
def test_datastream_zip():
    '''
    Zip sequentially sampled datastreams and check the first batch. Each
    zipped datastream contributes its next three examples, restarting from
    its first example once exhausted.
    '''
    datasets = [
        Dataset.from_subscriptable(values)
        for values in ([1, 2], [3, 4, 5], [6, 7])
    ]

    datastreams = [
        Datastream(ds, sampler=torch.utils.data.SequentialSampler(ds))
        for ds in datasets
    ]
    loader = Datastream.zip(datastreams).data_loader(batch_size=3)
    batch = next(iter(loader))

    assert len(batch) == 3 and len(batch[0]) == 3
    expected_columns = [(1, 2, 1), (3, 4, 5), (6, 7, 6)]
    for column, expected in zip(batch, expected_columns):
        for position, value in enumerate(expected):
            assert column[position] == value
Ejemplo n.º 12
0
    def zip_samplers(samplers, datasets):
        '''
        Build an iterator over combined-dataset indices by drawing one index
        from each sampler at a time.

        Each sampler is wrapped with ``repeat_map_chain(iter, sampler)``
        (presumably restarting the sampler once exhausted - confirm). The
        resulting iterators are zipped, and every tuple of per-sampler
        indices is passed to the mapping returned by
        ``Dataset.create_to_combine_mapping(datasets)``.
        '''
        to_mapping = Dataset.create_to_combine_mapping(datasets)

        # NOTE(review): assumes ``starcompose`` feeds each step's output to
        # the next step, unpacking it as positional arguments when it is a
        # tuple (hence ``tuple`` right before ``zip``) - confirm against its
        # definition.
        create_sampler = starcompose(
            partial(map, partial(repeat_map_chain, iter)),
            tuple,
            zip,
            partial(map, to_mapping),
        )
        return create_sampler(samplers)
Ejemplo n.º 13
0
 def __init__(self, samplers, datasets, ns):
     '''
     Store ``samplers``, ``datasets`` and the per-sampler counts ``ns``
     together with the values derived from them: the merged length, the
     concat mapping and the merged sampler iterator.
     '''
     merged_length = MergeSampler.merged_samplers_length(samplers)
     from_mapping = Dataset.create_from_concat_mapping(datasets)
     merged = MergeSampler.merge_samplers(samplers, datasets, ns)
     BaseModel.__init__(
         self,
         samplers=samplers,
         datasets=datasets,
         ns=ns,
         length=merged_length,
         from_mapping=from_mapping,
         merged_samplers=merged,
     )
Ejemplo n.º 14
0
def test_sequential_sampler():
    '''
    ``SequentialSampler`` works with ``take`` and yields examples in dataset
    order, starting over when more examples are requested.
    '''
    from datastream.samplers import SequentialSampler

    dataset = Dataset.from_subscriptable(list('abc'))
    n_examples = len(dataset)

    first_two = Datastream(dataset, SequentialSampler(n_examples)).take(2)
    assert len(list(first_two.data_loader(batch_size=1))) == 2

    ordered = Datastream(dataset, SequentialSampler(n_examples))
    batches = iter(
        ordered.data_loader(batch_size=6, n_batches_per_epoch=10)
    )
    assert next(batches) == ['a', 'b', 'c', 'a', 'b', 'c']
Ejemplo n.º 15
0
def MnistDataset(dataframe):
    '''
    Build a dataset of ``problem.Example`` items from ``dataframe`` rows.
    Rows are read via their ``image_path`` and ``class_name`` entries.

    Images are opened relative to the ``prepare`` directory.
    '''
    def path_and_class(row):
        return Path(row['image_path']), row['class_name']

    def make_example(image_path, class_name):
        image = Image.open('prepare' / image_path)
        return problem.Example(image=image, class_name=class_name)

    rows = Dataset.from_dataframe(dataframe)
    return rows.map(path_and_class).starmap(make_example)
Ejemplo n.º 16
0
def test_merge_datastream_weights():
    '''
    Smoke test for example weights on a merged datastream: zero some
    weights, draw batches, then rescale every weight.
    '''
    datasets = [
        Dataset.from_subscriptable(values)
        for values in ([1, 2], [3, 4, 5], [6, 7])
    ]

    def to_record(integer, index):
        return dict(integer=integer, index=index)

    merged = Datastream.merge([Datastream(dataset) for dataset in datasets])
    datastream = merged.zip_index().starmap(to_record).sample_proportion(0.5)

    removed_indices = [0, 3]
    for removed_index in removed_indices:
        datastream.update_example_weight_(0.0, removed_index)

    samples = list(
        datastream.data_loader(batch_size=4, n_batches_per_epoch=4)
    )

    # Blend every weight 10% of the way towards 1.
    datastream.update_weights_(lambda weights: weights * 0.9 + 1 * 0.1)
Ejemplo n.º 17
0
 def zip(datastreams: List[Datastream]) -> Datastream[Tuple]:
     '''
     Zip multiple datastreams together so that all combinations of examples
     are possible (i.e. the product), producing tuples like
     ``(example1, example2, ...)``. Samples are drawn independently from
     each underlying datastream.
     '''
     combined = Dataset.combine([stream.dataset for stream in datastreams])
     sampler_dataset_pairs = [
         (stream.sampler, stream.dataset) for stream in datastreams
     ]
     # Transpose the pairs into ZipSampler's (samplers, datasets) arguments.
     return Datastream(combined, ZipSampler(*zip(*sampler_dataset_pairs)))
Ejemplo n.º 18
0
def test_combine_concat_merge():
    '''
    Merge a concatenation of a zipped and a combined dataset with another
    zipped dataset. Iterating the result yields two examples.
    '''
    zipped_pair = Dataset.zip([
        Dataset.from_subscriptable([1]),
        Dataset.from_subscriptable([2]),
    ])
    combined = Dataset.combine([
        Dataset.from_subscriptable([3, 3]),
        Dataset.from_subscriptable([4, 4, 4]),
    ])
    dataset = Dataset.concat([zipped_pair, combined])

    datastream = Datastream.merge([
        Datastream(dataset),
        Datastream(
            Dataset.zip([
                Dataset.from_subscriptable([5]),
                Dataset.from_subscriptable([6]),
            ])),
    ])

    n_examples = len(list(datastream))
    assert n_examples == 2
Ejemplo n.º 19
0
def CifarDataset(dataframe):
    '''
    Build a dataset of ``problem.Example`` items from ``dataframe`` rows.
    Rows are read via their ``image_path`` and ``class_name`` entries.

    Images are opened relative to the ``prepare`` directory and converted
    to numpy arrays.
    '''
    def path_and_class(row):
        return Path(row["image_path"]), row["class_name"]

    def load_example(image_path, class_name):
        pixels = np.array(Image.open("prepare" / image_path))
        return problem.Example(image=pixels, class_name=class_name)

    rows = Dataset.from_dataframe(dataframe)
    return rows.map(path_and_class).starmap(load_example)
Ejemplo n.º 20
0
    def merge_samplers(samplers, datasets, ns):
        '''
        Build an iterator over concatenated-dataset indices that interleaves
        the given samplers in rounds. Each round takes ``ns[i]`` indices from
        sampler ``i``, for every sampler in order.

        Each sampler index is translated with the mapping returned by
        ``Dataset.create_to_concat_mapping(datasets)``, called as
        ``to_mapping(dataset_index, index)`` where ``dataset_index`` is the
        sampler's position.
        '''
        to_mapping = Dataset.create_to_concat_mapping(datasets)

        def batch(iterable, n):
            '''Endlessly yield lists of the next ``n`` items of ``iterable``.'''
            # NOTE(review): if ``iterable`` is ever exhausted, ``next``
            # raises StopIteration inside this generator, which Python 3.7+
            # turns into RuntimeError (PEP 479). This relies on
            # ``repeat_map_chain`` presumably never ending - confirm.
            while True:
                yield [next(iterable) for _ in range(n)]

        # Each item of ``index_batch`` is one round: a tuple holding one
        # list of mapped indices per sampler.
        index_batch = zip(*[
            batch(
                map(
                    partial(to_mapping, dataset_index),
                    repeat_map_chain(iter, sampler),
                ), n)
            for dataset_index, (sampler, n) in enumerate(zip(samplers, ns))
        ])

        # Flatten rounds -> per-sampler lists -> individual indices.
        return chain.from_iterable(chain.from_iterable(index_batch))
Ejemplo n.º 21
0
def test_datastream_simple_weights():
    '''
    Examples whose weight has been set to zero are never drawn.
    '''
    dataset = Dataset.from_subscriptable([1, 2, 3, 4])
    datastream = (
        Datastream(dataset).zip_index().starmap(lambda integer, index: dict(
            integer=integer,
            index=index,
        )).sample_proportion(0.5))

    removed_indices = [0, 3]
    for index in removed_indices:
        # Zero out one example per call, using the same (weight, index)
        # form as the merged-datastream weights test.
        datastream.update_example_weight_(0.0, index)

    samples = list(datastream.data_loader(batch_size=1))

    assert len(samples) == 2

    for sample in samples:
        if sample['index'] in removed_indices:
            raise AssertionError(
                'Samples with 0 weight were drawn from the dataset')
Ejemplo n.º 22
0
def test_infinite():
    '''
    With ``n_batches_per_epoch`` set, ten batches of size eight can be drawn
    from a three-example datastream.
    '''
    loader = Datastream(
        Dataset.from_subscriptable(list('abc'))
    ).data_loader(batch_size=8, n_batches_per_epoch=10)
    batches = iter(loader)
    for _ in range(10):
        next(batches)
Ejemplo n.º 23
0
def test_iter():
    '''Iterating a datastream directly yields one item per dataset example.'''
    datastream = Datastream(Dataset.from_subscriptable(list('abc')))
    n_items = sum(1 for _ in datastream)
    assert n_items == 3
Ejemplo n.º 24
0
def test_empty():
    '''Building a datastream over an empty dataset raises ValueError.'''
    import pytest

    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:
        Datastream(Dataset.from_subscriptable([]))
Ejemplo n.º 25
0
 def RandomDatastream():
     '''
     Create a datastream over ``range(n)`` for a random ``n`` from 1 to 9
     inclusive.
     '''
     n_examples = np.random.randint(1, 10)
     examples = list(range(n_examples))
     return Datastream(Dataset.from_subscriptable(examples))