def test_combined_pk_seq_sampler():
    """Check that ConcatenatedSamplerLongest mixes a PK and a sequential sampler.

    Every yielded batch must contain exactly ``pk_sampler.batch_size`` entries
    tagged with the first dataset's name and ``seq_sampler.batch_size`` entries
    tagged with the second's. After the full pass, ``pk_sampler.completed_iter``
    must equal ``len(seq_sampler) // len(pk_sampler)`` -- presumably the number
    of times the shorter PK sampler was run to completion while the longer
    sequential sampler was consumed.
    """
    num_pids = 40
    size = 150
    size2 = 200
    pk_cfg = {"P": 6, "K": 2}
    sequential_cfg = {"batch_size": 5}

    dataset1 = DummyDataset(lambda: create_dummy_pid_data(size, num_pids), "dummy1")
    dataset2 = DummyDataset(lambda: create_dummy_pid_data(size2, num_pids), "dummy2")

    pk_sampler = PKSampler.build(dataset1, pk_cfg)
    seq_sampler = SequentialSampler.build(dataset2, sequential_cfg)
    sampler = ConcatenatedSamplerLongest([pk_sampler, seq_sampler])

    expected_pk_iterations = len(seq_sampler) // len(pk_sampler)

    for batch in sampler:
        # Each batch entry is unpacked as (dataset name, index).
        per_dataset = dict.fromkeys((dataset1.name, dataset2.name), 0)
        for source_name, _ in batch:
            per_dataset[source_name] += 1
        assert per_dataset[dataset1.name] == pk_sampler.batch_size
        assert per_dataset[dataset2.name] == seq_sampler.batch_size

    print(expected_pk_iterations, pk_sampler.completed_iter)
    assert pk_sampler.completed_iter == expected_pk_iterations
def test_concat_dataset():
    """ConcatDataset should serve every item of dataset1, then every item of dataset2.

    Items from the first dataset (built with pids) must have a pid other than -1.
    Items from the second (built without pids) must have pid == -1. Each item's
    path must start with its source dataset's name.
    """
    size1, size2 = 70, 100
    name1, name2 = "Dummy1", "Dummy2"
    dataset1 = DummyDataset(lambda: create_dummy_pid_data(size1, 30, name1), name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    dataset = ConcatDataset([dataset1, dataset2])
    assert len(dataset) == size1 + size2

    dataloader = DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        num_workers=1,
        collate_fn=build_collate_fn(dataset.header),
    )

    # No batch_size is passed, so (with the DataLoader's default of 1) every batch
    # holds a single sample. The batch index therefore equals the sample index,
    # and element [0] is that sample.
    for sample_idx, batch in enumerate(dataloader):
        in_first = sample_idx < size1
        assert batch['path'][0].startswith(name1 if in_first else name2)
        if in_first:
            assert batch['pid'][0] != -1
        else:
            assert batch['pid'][0] == -1
def test_concat_reid_dataset():
    """ConcatReidDataset.num_labels should be the sum of its parts' pid counts."""
    first_name, second_name = "Dummy1", "Dummy2"
    first_size, second_size = 70, 100
    first_pids, second_pids = 30, 30

    parts = [
        DummyDataset(
            lambda: create_dummy_pid_data(first_size, first_pids, first_name),
            first_name,
        ),
        DummyDataset(
            lambda: create_dummy_pid_data(second_size, second_pids, second_name),
            second_name,
        ),
    ]
    combined = ConcatReidDataset(parts)
    assert combined.num_labels == first_pids + second_pids
def test_reset():
    """Iterate the same RandomSamplerShortest object for three consecutive epochs.

    Smoke check only: nothing is asserted, and the batches are printed for
    inspection. It exercises re-iterating one sampler instance, presumably to
    confirm it resets between epochs.
    """
    size1, size2 = 10, 20
    name1, name2 = "Dummy1", "Dummy2"
    first = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    second = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    # Batch size 1, drop_last=True for each member sampler.
    sampler = RandomSamplerShortest(
        [BatchSampler(SequentialSampler(ds), 1, True) for ds in (first, second)]
    )
    print(len(sampler))

    for _epoch in range(3):
        for step, batch in enumerate(sampler):
            print(step, batch)
def test_combined_pk_pk_sampler():
    """Smoke test: fully iterate a ConcatenatedSamplerLongest over two PK samplers.

    There are no assertions. The test passes as long as construction and
    iteration raise nothing.
    """
    size1, size2 = 200, 300
    num_pids1, num_pids2 = 30, 40
    dataset1 = DummyDataset(lambda: create_dummy_pid_data(size1, num_pids1, "d1"), "dummy1")
    dataset2 = DummyDataset(lambda: create_dummy_pid_data(size2, num_pids2, "d2"), "dummy2")

    P, K = 4, 3
    pk_samplers = [PKSampler(P, K, ds, drop_last=True) for ds in (dataset1, dataset2)]
    sampler = ConcatenatedSamplerLongest(pk_samplers)

    for batch in sampler:
        # Consume every element of every batch.
        for _ in batch:
            pass
def test_random_sampler_length_weighted():
    """Print every batch drawn from a RandomSamplerLengthWeighted.

    TODO this is a check not a test: nothing is asserted, and the output is meant
    for manual inspection. It combines a 1-item and a 100-item dataset and passes
    ``[1, 1]`` to the sampler (presumably equal per-sampler weights).
    """
    size1, size2 = 1, 100
    name1, name2 = "Dummy1", "Dummy2"
    small = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    large = DummyDataset(lambda: create_dummy_data(size2, name2), name2)

    members = [BatchSampler(SequentialSampler(ds), 1, True) for ds in (small, large)]
    sampler = RandomSamplerLengthWeighted(members, [1, 1])
    print(len(sampler))

    for step, batch in enumerate(sampler):
        print(step, batch)
def test_shortest_concatenated_sampler():
    """ConcatenatedSamplerShortest should stop once its shortest member is exhausted.

    With two batch-size-1 sequential samplers over datasets of 70 and 100 items,
    the combined sampler must yield exactly ``min(size1, size2)`` batches.
    """
    size1 = 70
    size2 = 100
    name1 = "Dummy1"
    name2 = "Dummy2"
    dataset1 = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)
    sampler1 = BatchSampler(SequentialSampler(dataset1), 1, True)
    sampler2 = BatchSampler(SequentialSampler(dataset2), 1, True)
    sampler = ConcatenatedSamplerShortest([sampler1, sampler2])
    print(len(sampler))

    # Count batches explicitly. Relying on the loop variable leaking out of an
    # ``enumerate`` loop raises NameError (instead of a clear assertion failure)
    # when the sampler yields nothing.
    num_batches = sum(1 for _ in sampler)
    assert num_batches == min(size1, size2)
def test_random_sampler():
    """Two passes over the same RandomSampler should yield different index orders."""
    sampler = RandomSampler(DummyDataset(lambda: create_dummy_data(100), "dummy"))

    first_pass = [idx for idx in sampler]
    second_pass = [idx for idx in sampler]
    print(first_pass, second_pass)

    # If each pass is an independent shuffle of the 100 items, repeating the
    # previous order by chance is vanishingly unlikely. Equality therefore means
    # the sampler did not reshuffle.
    assert first_pass != second_pass
def test_switching_sampler_longest():
    """SwitchingSamplerLongest should round-robin its members until all are exhausted.

    With batch-size-1 sequential samplers over datasets of sizes
    size1 <= size2 <= size3, the expected source order is:

    * the first ``3 * size1`` batches cycle dataset1, dataset2, dataset3;
    * the next ``2 * (size2 - size1)`` batches cycle dataset2, dataset3;
    * the remaining batches all come from dataset3.

    The total number of batches is ``size1 + size2 + size3``. ``batch[0][0]`` is
    assumed to be the dataset name of the batch's single entry.
    """
    size1 = 70
    size2 = 100
    size3 = 150
    name1 = "Dummy1"
    name2 = "Dummy2"
    name3 = "Dummy3"
    # The phase layout below relies on this ordering of the sizes.
    assert size1 <= size2 <= size3

    dataset1 = DummyDataset(lambda: create_dummy_data(size1, name1), name1)
    dataset2 = DummyDataset(lambda: create_dummy_data(size2, name2), name2)
    dataset3 = DummyDataset(lambda: create_dummy_data(size3, name3), name3)
    sampler1 = BatchSampler(SequentialSampler(dataset1), 1, True)
    sampler2 = BatchSampler(SequentialSampler(dataset2), 1, True)
    sampler3 = BatchSampler(SequentialSampler(dataset3), 1, True)
    sampler = SwitchingSamplerLongest([sampler1, sampler2, sampler3])
    print(len(sampler))

    # Phase boundaries are derived from the sizes instead of hard-coded
    # (previously 70 * 3 and 70 * 3 + 30 * 2).
    end_three_way = 3 * size1
    end_two_way = end_three_way + 2 * (size2 - size1)

    num_batches = 0
    for idx, batch in enumerate(sampler):
        if idx < end_three_way:
            expected = (name1, name2, name3)[idx % 3]
        elif idx < end_two_way:
            # Offset by the phase start so the parity is right even when
            # end_three_way is odd.
            expected = (name2, name3)[(idx - end_three_way) % 2]
        else:
            expected = name3
        assert batch[0][0] == expected
        num_batches += 1

    # Counted explicitly so that an empty sampler fails the assertion rather
    # than raising NameError on a never-bound loop variable.
    assert num_batches == size1 + size2 + size3
def test_create_pid2idxs():
    """create_pids2idxs should produce one entry per pid.

    The dummy data is built with 200 items and ``num_pids`` = 40 pids.
    """
    expected_pids = 40
    pid_to_idxs = create_pids2idxs(
        DummyDataset(lambda: create_dummy_pid_data(200, expected_pids), "dummy")
    )
    assert len(pid_to_idxs) == expected_pids