def test_combined_pk_seq_sampler():
    """Check ConcatenatedSamplerLongest over a PK sampler and a sequential sampler.

    Every combined batch must hold exactly ``pk_sampler.batch_size`` samples
    tagged with the first dataset's name and ``seq_sampler.batch_size``
    samples tagged with the second. After the full pass, the PK sampler's
    ``completed_iter`` counter is expected to equal
    ``len(seq_sampler) // len(pk_sampler)``.
    """
    num_pids = 40
    size, size2 = 150, 200
    pk_cfg = {"P": 6, "K": 2}
    sequential_cfg = {"batch_size": 5}

    dataset1 = DummyDataset(lambda: create_dummy_pid_data(size, num_pids), "dummy1")
    dataset2 = DummyDataset(lambda: create_dummy_pid_data(size2, num_pids), "dummy2")

    pk_sampler = PKSampler.build(dataset1, pk_cfg)
    seq_sampler = SequentialSampler.build(dataset2, sequential_cfg)
    sampler = ConcatenatedSamplerLongest([pk_sampler, seq_sampler])

    # The number of sequential batches drives how many times the (shorter)
    # PK sampler is expected to be cycled through.
    num_seq_batches = len(seq_sampler)
    num_pk_batches = len(pk_sampler)
    iterations_pk = num_seq_batches // num_pk_batches

    for combined in sampler:
        # Each element of a combined batch is a (dataset_name, index) pair.
        per_source = {dataset1.name: 0, dataset2.name: 0}
        for source_name, _ in combined:
            per_source[source_name] += 1
        assert per_source[dataset1.name] == pk_sampler.batch_size
        assert per_source[dataset2.name] == seq_sampler.batch_size

    print(iterations_pk, pk_sampler.completed_iter)
    assert pk_sampler.completed_iter == iterations_pk
# Example #2
def test_concat_dataset():
    """ConcatDataset should yield dataset1's samples first, then dataset2's.

    Samples from the pid-labelled first dataset must carry a pid other than
    ``-1``; samples from the unlabelled second dataset must carry ``-1``.
    """
    size1, size2 = 70, 100
    name1, name2 = "Dummy1", "Dummy2"
    dataset1 = DummyDataset(
        lambda: create_dummy_pid_data(size1, 30, name1), name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    dataset = ConcatDataset([dataset1, dataset2])
    assert len(dataset) == size1 + size2

    dataloader = DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        num_workers=1,
        collate_fn=build_collate_fn(dataset.header),
    )

    for position, data in enumerate(dataloader):
        # The DataLoader uses its default batch size here, so element 0 of
        # each field is the sample at `position`.
        from_first = position < size1
        expected_prefix = name1 if from_first else name2
        assert data['path'][0].startswith(expected_prefix)
        if from_first:
            assert data['pid'][0] != -1
        else:
            assert data['pid'][0] == -1
# Example #3
def test_concat_reid_dataset():
    """ConcatReidDataset should report the sum of its members' pid counts."""
    size1, size2 = 70, 100
    name1, name2 = "Dummy1", "Dummy2"
    pid1 = pid2 = 30
    members = [
        DummyDataset(lambda: create_dummy_pid_data(size1, pid1, name1), name1),
        DummyDataset(lambda: create_dummy_pid_data(size2, pid2, name2), name2),
    ]
    combined = ConcatReidDataset(members)
    assert combined.num_labels == pid1 + pid2
def test_reset():
    """Iterate a RandomSamplerShortest three times in a row.

    Smoke check only: nothing is asserted, so it passes as long as repeated
    passes over the sampler do not raise. Batches are printed for inspection.
    """
    size1, size2 = 10, 20
    name1, name2 = "Dummy1", "Dummy2"
    dataset1 = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    sub_samplers = [
        BatchSampler(SequentialSampler(ds), 1, True)
        for ds in (dataset1, dataset2)
    ]
    sampler = RandomSamplerShortest(sub_samplers)

    print(len(sampler))
    for _epoch in range(3):
        for step, batch in enumerate(sampler):
            print(step, batch)
def test_combined_pk_pk_sampler():
    """Smoke-test ConcatenatedSamplerLongest over two PK samplers.

    Only checks that every combined batch can be iterated to the end
    without raising; no values are asserted.
    """
    size1, size2 = 200, 300
    num_pids1, num_pids2 = 30, 40
    P, K = 4, 3

    dataset1 = DummyDataset(lambda: create_dummy_pid_data(size1, num_pids1, "d1"), "dummy1")
    dataset2 = DummyDataset(lambda: create_dummy_pid_data(size2, num_pids2, "d2"), "dummy2")

    pk_samplers = [
        PKSampler(P, K, ds, drop_last=True) for ds in (dataset1, dataset2)
    ]
    sampler = ConcatenatedSamplerLongest(pk_samplers)

    for batch in sampler:
        # Drain the batch; reaching the end without an exception is the check.
        for _sample in batch:
            continue
def test_random_sampler_length_weighted():
    """Print the batches drawn by RandomSamplerLengthWeighted.

    TODO: this is a check, not a test -- nothing is asserted.
    """
    size1, size2 = 1, 100
    name1, name2 = "Dummy1", "Dummy2"
    dataset1 = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    sub_samplers = [
        BatchSampler(SequentialSampler(ds), 1, True)
        for ds in (dataset1, dataset2)
    ]
    sampler = RandomSamplerLengthWeighted(sub_samplers, [1, 1])

    print(len(sampler))
    for step, batch in enumerate(sampler):
        print(step, batch)
def test_shortest_concatenated_sampler():
    """ConcatenatedSamplerShortest should stop when its shortest member ends.

    Both member samplers are built with batch size 1, so the combined
    sampler is expected to yield ``min(size1, size2)`` batches.
    """
    size1 = 70
    size2 = 100
    name1 = "Dummy1"
    name2 = "Dummy2"
    dataset1 = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    sampler1 = BatchSampler(SequentialSampler(dataset1), 1, True)
    sampler2 = BatchSampler(SequentialSampler(dataset2), 1, True)

    sampler = ConcatenatedSamplerShortest([sampler1, sampler2])

    print(len(sampler))
    # Count explicitly rather than reading the loop variable after the loop.
    # If the sampler yields nothing, that variable is never bound and the
    # test would crash with UnboundLocalError instead of failing the assert.
    num_batches = sum(1 for _ in sampler)

    assert num_batches == min(size1, size2)
def test_random_sampler():
    """Two passes over a RandomSampler should produce different orderings.

    NOTE(review): this is probabilistic -- two passes could in principle
    coincide, although that is extremely unlikely for a non-trivial dataset.
    """
    dataset = DummyDataset(lambda: create_dummy_data(100), "dummy")
    sampler = RandomSampler(dataset)

    first_pass = [drawn for drawn in sampler]
    second_pass = [drawn for drawn in sampler]

    print(first_pass, second_pass)
    assert first_pass != second_pass
def test_switching_sampler_longest():
    """SwitchingSamplerLongest should round-robin over the still-active samplers.

    With three sequential samplers of increasing length (batch size 1), the
    expected order of batch sources is:

    * while all three are active, batches cycle dataset1, dataset2, dataset3;
    * once dataset1 runs out, they alternate dataset2, dataset3;
    * once dataset2 runs out, only dataset3 remains.

    The total number of batches is expected to be ``size1 + size2 + size3``.
    """
    size1 = 70
    size2 = 100
    size3 = 150
    name1 = "Dummy1"
    name2 = "Dummy2"
    name3 = "Dummy3"
    dataset1 = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)
    dataset3 = DummyDataset(lambda: create_dummy_data(size3, name3), name3)

    sampler1 = BatchSampler(SequentialSampler(dataset1), 1, True)
    sampler2 = BatchSampler(SequentialSampler(dataset2), 1, True)
    sampler3 = BatchSampler(SequentialSampler(dataset3), 1, True)

    sampler = SwitchingSamplerLongest([sampler1, sampler2, sampler3])

    # Phase boundaries are derived from the sizes instead of hard-coding
    # 70 * 3 and 30 * 2. The test assumes size1 < size2 < size3.
    end_three_way = 3 * size1
    end_two_way = end_three_way + 2 * (size2 - size1)

    print(len(sampler))
    num_batches = 0
    for idx, batch in enumerate(sampler):
        if idx < end_three_way:
            expected = (name1, name2, name3)[idx % 3]
        elif idx < end_two_way:
            # Use the offset within this phase so the alternation does not
            # depend on whether 3 * size1 happens to be even.
            expected = (name2, name3)[(idx - end_three_way) % 2]
        else:
            expected = name3
        assert batch[0][0] == expected
        num_batches += 1

    # Counted explicitly so an empty sampler fails this assert rather than
    # raising UnboundLocalError on the loop variable.
    assert num_batches == size1 + size2 + size3
def test_create_pid2idxs():
    """The mapping built by create_pids2idxs should have num_pids entries."""
    num_pids = 40
    dataset = DummyDataset(lambda: create_dummy_pid_data(200, num_pids), "dummy")
    assert len(create_pids2idxs(dataset)) == num_pids