def test_datastream_merge():
    """Merged datastreams can be sampled from and yield ten batches.

    NOTE(review): this module defines ``test_datastream_merge`` a second time
    further down; the later definition rebinds the name, so pytest only
    collects that one and this version never runs. Consider renaming one.
    """
    datastream = Datastream.merge([
        Datastream(Dataset.from_subscriptable(list('abc'))),
        Datastream(Dataset.from_subscriptable(list('def'))),
    ])

    # Pull a couple of indices straight from the merged sampler.
    it = iter(datastream.sampler)
    for _ in range(2):
        next(it)

    # A default epoch is sized from the (tiny) data, so with batch_size=8 it
    # would end long before ten batches; request ten batches explicitly.
    it = iter(datastream.data_loader(batch_size=8, n_batches_per_epoch=10))
    for _ in range(10):
        next(it)
def test_last_batch():
    """Check the size of the final batch of an epoch for different samplers."""
    from datastream.samplers import SequentialSampler

    def batch_sizes(loader):
        return [len(batch) for batch in loader]

    stream = Datastream(Dataset.from_subscriptable(list('abc')))
    # Three examples with batch_size=4 end up in a single short batch.
    assert batch_sizes(stream.data_loader(batch_size=4)) == [3]
    # Asking for two batches per epoch yields two full batches.
    assert batch_sizes(
        stream.data_loader(batch_size=4, n_batches_per_epoch=2)
    ) == [4, 4]

    sequential_stream = Datastream(
        Dataset.from_subscriptable(list('abc')),
        SequentialSampler(3),
    )
    # The sequential sampler leaves a short final batch.
    assert batch_sizes(sequential_stream.data_loader(batch_size=2)) == [2, 1]
def test_take():
    """``take(n)`` limits the number of examples and rejects ``take(0)``."""
    import pytest

    def letters(text):
        return Datastream(Dataset.from_subscriptable(list(text)))

    limited = letters('abc').take(2)
    assert len(list(limited.data_loader(batch_size=1))) == 2

    with pytest.raises(ValueError):
        letters('abc').take(0)

    merged = Datastream.merge([letters('abc'), letters('d')])
    assert len(list(merged.take(2).data_loader(batch_size=1))) == 2
def test_multi_sample():
    """Multi-sampled streams honour per-example weight updates.

    An epoch contains ``len(data) * n_multi_sample`` draws. After zeroing the
    weight of the examples seen in the first two draws, none of those
    examples may be drawn in the next epoch.
    """
    data = [1, 2, 4]
    n_multi_sample = 2
    datastream = (
        Datastream(Dataset.from_subscriptable(data))
        .map(lambda number: number ** 2)
        .multi_sample(n_multi_sample)
        .sample_proportion(0.5)
        .zip_index()
        .starmap(lambda number, index: (number ** 0.5, index))
    )

    output = [
        (number, index)
        for number, index in datastream.data_loader(batch_size=1)
    ]
    assert len(output) == len(data) * n_multi_sample

    state = datastream.state_dict()
    datastream.load_state_dict(state)

    # The loader presumably collates indices into tensors; convert them to
    # plain ints so set membership compares values (tensors hash by identity).
    zero_indices = {int(index) for _, index in output[:2]}
    for index in zero_indices:
        # (weight, index) argument order, as in the other weight tests.
        # NOTE(review): if multi-sampled streams require one weight per clone,
        # pass [0.0] * n_multi_sample instead of 0.0.
        datastream.update_example_weight_(0.0, index)

    output2 = [
        (number, index)
        for number, index in datastream.data_loader(batch_size=1)
    ]
    assert len(output2) == len(data) * n_multi_sample
    for _, index in output2:
        assert int(index) not in zero_indices
def test_datastream_merge():
    """Merged datastreams can be sampled and batched repeatedly.

    Also checks that an epoch with ``batch_size=1`` has ``len(datastream)``
    batches.
    """
    sources = [list('abc'), list('def')]
    datastream = Datastream.merge([
        Datastream(Dataset.from_subscriptable(letters)) for letters in sources
    ])

    sampler_iterator = iter(datastream.sampler)
    next(sampler_iterator)
    next(sampler_iterator)

    loader_iterator = iter(
        datastream.data_loader(batch_size=8, n_batches_per_epoch=10)
    )
    for _ in range(10):
        next(loader_iterator)

    n_batches = sum(1 for _ in datastream.data_loader(batch_size=1))
    assert n_batches == len(datastream)
def test_concat_merge():
    """Merge a concatenated dataset with a subset of itself."""
    def first_three(df):
        # Boolean mask keeping only the first three positions.
        return [position < 3 for position in range(len(df))]

    dataset = Dataset.concat([
        Dataset.from_subscriptable([1, 2]),
        Dataset.from_subscriptable([1, 3, 5]),
    ])
    datastream = Datastream.merge([
        Datastream(dataset),
        Datastream(dataset.subset(first_three)),
    ])

    assert len(dataset.subset(first_three)) == 3
    assert len(list(datastream)) == 6
def test_datastream_zip():
    """Zipping datastreams gives one batch column per source stream.

    Each source uses a sequential sampler. In a batch of three, the
    two-element sources start over from their first element.
    """
    datasets = [
        Dataset.from_subscriptable([1, 2]),
        Dataset.from_subscriptable([3, 4, 5]),
        Dataset.from_subscriptable([6, 7]),
    ]
    datastreams = [
        Datastream(source, sampler=torch.utils.data.SequentialSampler(source))
        for source in datasets
    ]
    loader = Datastream.zip(datastreams).data_loader(batch_size=3)
    batch = next(iter(loader))

    assert len(batch) == 3 and len(batch[0]) == 3
    expected_columns = [[1, 2, 1], [3, 4, 5], [6, 7, 6]]
    for column, expected in zip(batch, expected_columns):
        for position, value in enumerate(expected):
            assert column[position] == value
def test_sequential_sampler():
    """A sequential sampler works with ``take`` and repeats across batches."""
    from datastream.samplers import SequentialSampler

    dataset = Dataset.from_subscriptable(list('abc'))

    def sequential_stream():
        return Datastream(dataset, SequentialSampler(len(dataset)))

    taken = sequential_stream().take(2)
    assert len(list(taken.data_loader(batch_size=1))) == 2

    loader = sequential_stream().data_loader(
        batch_size=6, n_batches_per_epoch=10
    )
    assert next(iter(loader)) == ['a', 'b', 'c', 'a', 'b', 'c']
def test_merge_datastream_weights():
    """Zero-weighted examples of a merged datastream are never sampled.

    Also exercises ``update_weights_`` on the merged stream.
    """
    datasets = [
        Dataset.from_subscriptable([1, 2]),
        Dataset.from_subscriptable([3, 4, 5]),
        Dataset.from_subscriptable([6, 7]),
    ]
    datastream = (
        Datastream.merge([Datastream(dataset) for dataset in datasets])
        .zip_index()
        .starmap(lambda integer, index: dict(
            integer=integer,
            index=index,
        ))
        .sample_proportion(0.5)
    )

    removed_indices = [0, 3]
    for index in removed_indices:
        datastream.update_example_weight_(0.0, index)

    samples = list(datastream.data_loader(batch_size=4, n_batches_per_epoch=4))
    assert len(samples) == 4
    for batch in samples:
        # Presumably each batch is collated into a dict of tensors; int()
        # normalizes each index so the membership test compares values.
        for index in batch['index']:
            assert int(index) not in removed_indices, (
                'Samples with 0 weight were drawn from the dataset'
            )

    datastream.update_weights_(lambda weights: weights * 0.9 + 1 * 0.1)
def test_datastream_simple_weights():
    """Examples whose weight is set to zero are never drawn."""
    dataset = Dataset.from_subscriptable([1, 2, 3, 4])
    datastream = (
        Datastream(dataset)
        .zip_index()
        .starmap(lambda integer, index: dict(
            integer=integer,
            index=index,
        ))
        .sample_proportion(0.5)
    )

    removed_indices = [0, 3]
    for index in removed_indices:
        # Zero each removed example individually.
        datastream.update_example_weight_(0.0, index)

    samples = list(datastream.data_loader(batch_size=1))
    assert len(samples) == 2
    for sample in samples:
        assert sample['index'] not in removed_indices, (
            'Samples with 0 weight were drawn from the dataset'
        )
def test_combine_concat_merge():
    """Merge a concat of zipped and combined datasets with a zipped dataset."""
    def single(value):
        return Dataset.from_subscriptable([value])

    zipped = Dataset.zip([single(1), single(2)])
    combined = Dataset.combine([
        Dataset.from_subscriptable([3, 3]),
        Dataset.from_subscriptable([4, 4, 4]),
    ])
    dataset = Dataset.concat([zipped, combined])

    other = Dataset.zip([single(5), single(6)])
    datastream = Datastream.merge([Datastream(dataset), Datastream(other)])

    assert len(list(datastream)) == 2
def RandomDatastream():
    """Build a datastream over ``range(size)`` with ``size`` drawn from 1..9."""
    size = np.random.randint(1, 10)
    return Datastream(Dataset.from_subscriptable(list(range(size))))
def test_empty():
    """Building a datastream from an empty dataset raises ``ValueError``."""
    import pytest

    with pytest.raises(ValueError):
        Datastream(Dataset.from_subscriptable([]))
def test_iter():
    """Iterating a three-example datastream directly yields three items."""
    stream = Datastream(Dataset.from_subscriptable(list('abc')))
    n_items = sum(1 for _ in stream)
    assert n_items == 3
def test_infinite():
    """``n_batches_per_epoch`` lets a tiny dataset supply ten batches."""
    stream = Datastream(Dataset.from_subscriptable(list('abc')))
    batches = iter(stream.data_loader(batch_size=8, n_batches_per_epoch=10))
    for _ in range(10):
        next(batches)