Example #1
0
    def test_drop_non_intersecting_removes_all_elements_with_empty_intersection(self):
        """An empty intersection must leave both vertical partitions empty.

        The same (empty) index list is handed to ``drop_non_intersecting`` for
        both partitions; afterwards the first partition's ``data``/``ids`` and
        the second partition's ``targets``/``ids`` must all have length zero.
        """
        loader = VerticalDataLoader(self.dataset, batch_size=100)
        no_shared_indices = []

        # One list object is used for both sides, exactly as a caller with no
        # common ids would do.
        loader.drop_non_intersecting(no_shared_indices, no_shared_indices)

        first_partition = loader.dataloader1.dataset
        second_partition = loader.dataloader2.dataset
        assert len(first_partition.data) == 0
        assert len(first_partition.ids) == 0
        assert len(second_partition.targets) == 0
        assert len(second_partition.ids) == 0
Example #2
0
    def test_drop_non_intersecting_removes_elements(self):
        """Dropping with a shared index list keeps exactly those rows, in order.

        Both partitions are filtered with the same ``intersection``. Afterwards
        each partition must hold ``len(intersection)`` items, and every retained
        datapoint in the first partition must equal the one originally stored
        at the corresponding intersection index.
        """
        dataloader = VerticalDataLoader(self.dataset, batch_size=100)
        intersection = [0, 1, 2]
        # Snapshot the expected rows *before* dropping. ``clone()`` detaches the
        # snapshot from the original storage, so an in-place rewrite of ``data``
        # cannot make the comparison below pass trivially (a plain index would
        # return a view sharing that storage).
        expected_rows = [
            dataloader.dataloader1.dataset.data[i].clone() for i in intersection
        ]

        dataloader.drop_non_intersecting(intersection, intersection)

        n_kept = len(intersection)
        assert len(dataloader.dataloader1.dataset.data) == n_kept
        assert len(dataloader.dataloader1.dataset.ids) == n_kept
        assert len(dataloader.dataloader2.dataset.targets) == n_kept
        assert len(dataloader.dataloader2.dataset.ids) == n_kept
        # Check every surviving row, not just the first, so a wrong selection
        # or ordering of later rows is also caught.
        for position, expected in enumerate(expected_rows):
            assert torch.equal(expected, dataloader.dataloader1.dataset.data[position])
Example #3
0
    def test_datasets_have_same_ids_after_drop_non_intersecting(self):
        """Each partition keeps the ids found at its own intersection indices.

        The expected id lists are captured before dropping, in the order of the
        index lists passed in. ``intersection2`` is not sorted, so the
        position-wise comparison also checks that surviving rows follow the
        order of the supplied indices.
        """
        dataloader = VerticalDataLoader(self.dataset, batch_size=128)

        intersection1 = [0, 1, 5, 10]
        ids1 = [dataloader.dataloader1.dataset.ids[i] for i in intersection1]

        intersection2 = [7, 10, 12, 1]
        ids2 = [dataloader.dataloader2.dataset.ids[i] for i in intersection2]

        dataloader.drop_non_intersecting(intersection1, intersection2)

        # Assert lengths before the elementwise ``==``: assumes ``ids`` is an
        # array type (as ``.all()`` implies), and with mismatched lengths an
        # array-vs-list comparison does not give a clean positional failure.
        assert len(dataloader.dataloader1.dataset.data) == len(intersection1)
        assert len(dataloader.dataloader1.dataset.ids) == len(ids1)
        assert (dataloader.dataloader1.dataset.ids == ids1).all()

        assert len(dataloader.dataloader2.dataset.targets) == len(intersection2)
        assert len(dataloader.dataloader2.dataset.ids) == len(ids2)
        assert (dataloader.dataloader2.dataset.ids == ids2).all()