# Example #1
# 0
def test_concat_dataset():
    """Check that a ConcatDataset fed through a DataLoader yields every sample in order.

    ``dataset1`` is built with ``create_dummy_pid_data`` and ``dataset2`` with
    ``create_dummy_data``. The assertions expect samples from the first dataset
    to come first, with a real ``pid`` (!= -1). Samples from the second dataset
    are expected to carry ``pid == -1``.
    """
    size1 = 70
    size2 = 100
    name1 = "Dummy1"
    name2 = "Dummy2"
    dataset1 = DummyDataset(lambda: create_dummy_pid_data(size1, 30, name1),
                            name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    dataset = ConcatDataset([dataset1, dataset2])
    assert len(dataset) == size1 + size2

    sampler = SequentialSampler(dataset)

    collate_fn = build_collate_fn(dataset.header)

    # No batch_size is passed, so presumably the DataLoader default of 1
    # applies: each ``data`` holds one sample and ``idx`` is the sample index.
    dataloader = DataLoader(dataset,
                            sampler=sampler,
                            num_workers=1,
                            collate_fn=collate_fn)

    n_batches = 0
    for idx, data in enumerate(dataloader):
        if idx < size1:
            assert data['path'][0].startswith(name1)
            assert data['pid'][0] != -1
        else:
            assert data['path'][0].startswith(name2)
            assert data['pid'][0] == -1
        n_batches += 1

    # Without this check, a loader yielding too few (or zero) batches would
    # pass the per-sample assertions above vacuously.
    assert n_batches == size1 + size2
def test_reset():
    """Iterate a RandomSamplerShortest three times to exercise re-iteration.

    This is a smoke check only: it prints the sampler length and the
    batches of every pass, and asserts nothing.
    """
    name1, name2 = "Dummy1", "Dummy2"
    size1, size2 = 10, 20
    first = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    second = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    # One BatchSampler(SequentialSampler(...), 1, True) per dataset, in order.
    batch_samplers = [
        BatchSampler(SequentialSampler(ds), 1, True) for ds in (first, second)
    ]
    sampler = RandomSamplerShortest(batch_samplers)

    print(len(sampler))
    for _pass in range(3):
        for position, batch in enumerate(sampler):
            print(position, batch)
def test_random_sampler_length_weighted():
    """
    TODO this is a check not a test.

    Wraps two datasets of sizes 1 and 100 in BatchSamplers and combines them
    with RandomSamplerLengthWeighted using weights [1, 1]. It then prints the
    sampler length and every batch for manual inspection; nothing is asserted.
    """
    name1, name2 = "Dummy1", "Dummy2"
    size1, size2 = 1, 100
    small = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    large = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    batch_samplers = [
        BatchSampler(SequentialSampler(ds), 1, True) for ds in (small, large)
    ]
    weights = [1, 1]
    sampler = RandomSamplerLengthWeighted(batch_samplers, weights)

    print(len(sampler))
    for position, batch in enumerate(sampler):
        print(position, batch)
def test_shortest_concatenated_sampler():
    """Check that ConcatenatedSamplerShortest yields as many batches as its shortest input.

    Each input is a BatchSampler built with arguments ``(…, 1, True)``. Assuming
    the torch ``BatchSampler(sampler, batch_size, drop_last)`` signature, that is
    batch size 1. The expected count is therefore ``min(size1, size2)``.
    """
    size1 = 70
    size2 = 100
    name1 = "Dummy1"
    name2 = "Dummy2"
    dataset1 = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    sampler1 = BatchSampler(SequentialSampler(dataset1), 1, True)
    sampler2 = BatchSampler(SequentialSampler(dataset2), 1, True)

    sampler = ConcatenatedSamplerShortest([sampler1, sampler2])

    print(len(sampler))
    # Count explicitly: relying on the last ``idx`` of an enumerate loop raises
    # UnboundLocalError (not an assertion failure) when the sampler is empty.
    n_batches = sum(1 for _ in sampler)

    assert n_batches == min(size1, size2)
def test_random_sampler():
    """Check that two passes over the same RandomSampler give different orderings.

    The dataset is built with ``create_dummy_data(100)``. Presumably each pass
    draws a fresh random permutation, so two identical passes indicate the
    sampler is not reshuffling between iterations.
    """
    dataset = DummyDataset(lambda: create_dummy_data(100), "dummy")
    sampler = RandomSampler(dataset)

    # Each list() call starts a new iteration over the sampler.
    idxs1 = list(sampler)
    idxs2 = list(sampler)

    print(idxs1, idxs2)
    assert idxs1 != idxs2
def test_switching_sampler_longest():
    """Check that SwitchingSamplerLongest round-robins until every sampler is exhausted.

    The three datasets have strictly increasing sizes, so batches arrive in
    three phases:

    1. ``3 * size1`` batches cycle through name1, name2, name3.
    2. ``2 * (size2 - size1)`` batches alternate name2, name3.
    3. The remaining batches all come from name3.

    The total must be ``size1 + size2 + size3``. ``batch[0][0]`` is compared
    against the dataset name. Presumably the switching sampler tags each entry
    with its source name — TODO confirm against SwitchingSamplerLongest.
    """
    size1 = 70
    size2 = 100
    size3 = 150
    name1 = "Dummy1"
    name2 = "Dummy2"
    name3 = "Dummy3"
    dataset1 = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)
    dataset3 = DummyDataset(lambda: create_dummy_data(size3, name3), name3)

    sampler1 = BatchSampler(SequentialSampler(dataset1), 1, True)
    sampler2 = BatchSampler(SequentialSampler(dataset2), 1, True)
    sampler3 = BatchSampler(SequentialSampler(dataset3), 1, True)

    sampler = SwitchingSamplerLongest([sampler1, sampler2, sampler3])

    # Phase boundaries derived from the sizes instead of hard-coded 70 / 30,
    # so the expectations stay correct if the sizes change.
    names = [name1, name2, name3]
    phase1_end = 3 * size1
    phase2_end = phase1_end + 2 * (size2 - size1)

    print(len(sampler))
    n_batches = 0
    for idx, batch in enumerate(sampler):
        if idx < phase1_end:
            expected = names[idx % 3]
        elif idx < phase2_end:
            # Parity is taken relative to the start of phase 2; the original
            # absolute ``idx % 2`` only worked because 3 * size1 was even.
            expected = names[1 + (idx - phase1_end) % 2]
        else:
            expected = name3
        assert batch[0][0] == expected
        n_batches += 1

    # Counting explicitly avoids an UnboundLocalError on an empty sampler.
    assert n_batches == size1 + size2 + size3