def test_image_pipeline_and_pin_memory(self):
        """
        Smoke test for a mirror / transpose / to-tensor pipeline with memory pinning.

        Draws 50 batches from a MultiThreadedAugmenter and checks that the final
        batch's 'data' entry is a torch.Tensor whose memory is pinned.

        Skipped (instead of silently passing) when torch is not installed.
        """
        try:
            import torch
        except ImportError:
            # Returning early would make this test show up as "passed" although
            # nothing was exercised; report it as skipped instead. SkipTest is
            # understood by both unittest and pytest runners.
            import unittest
            raise unittest.SkipTest("torch is not installed") from None

        from batchgenerators.transforms import MirrorTransform, NumpyToTensor, TransposeAxesTransform, Compose

        tr_transforms = [
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
            # Convert only the 'data' key to a float tensor so it can be pinned.
            NumpyToTensor(keys='data', cast_to='float'),
        ]
        composed = Compose(tr_transforms)

        dl = self.dl_images
        # NOTE(review): the trailing positional True presumably enables pin_memory
        # (the is_pinned() assertion below relies on it) -- confirm against the
        # MultiThreadedAugmenter signature.
        mt = MultiThreadedAugmenter(dl, composed, 4, 1, None, True)

        for _ in range(50):
            res = mt.next()

        assert isinstance(res['data'], torch.Tensor)
        assert res['data'].is_pinned()

        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)
    # Example 2
    def test_no_crash(self):
        """
        Smoke test: draw 20 batches from a MultiThreadedAugmenter that wraps the
        image data loader without any transform. Passes as long as nothing raises.
        """
        loader = self.dl_images
        augmenter = MultiThreadedAugmenter(loader, None, self.num_threads, 1, None, False)

        batches_to_draw = 20
        for _ in range(batches_to_draw):
            augmenter.next()
    # Example 3
    def test_image_pipeline(self):
        """
        Smoke test: push the image data loader through a MirrorTransform +
        TransposeAxesTransform pipeline inside a MultiThreadedAugmenter and draw
        50 batches. Passes as long as nothing raises.
        """
        from batchgenerators.transforms import MirrorTransform, TransposeAxesTransform, Compose

        pipeline = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
        ])

        loader = self.dl_images
        augmenter = MultiThreadedAugmenter(loader, pipeline, 4, 1, None, False)

        for _ in range(50):
            augmenter.next()

        # Give the augmenter time to finish caching before the test returns;
        # otherwise it prints an error that is harmless and does not affect the
        # test result, but is noisy.
        sleep(2)