Example #1
    def __init__(self,
                 which_set='train',
                 batch_size=16,
                 mask_type='mix',
                 dilation=True,
                 noisy_texture=True,
                 rotation=True,
                 image_size=256,
                 num_workers=0):
        """
        Initialization of synthetic root data loader
        :param which_set: 'train'|'valid'|'test'
        :param batch_size: number of patches to be extracted from a single segmentation mask
        :param mask_type: gap type 'square'|'blob'|'brush'|'mix'
        :param dilation: whether to apply root dilation
        :param noisy_texture: whether to add noisy texture
        :param rotation: whether to apply root rotation
        :param image_size: patch size
        :param num_workers: number of workers, normally set to 0
        """

        assert mask_type in ['square', 'blob', 'brush', 'mix']
        if mask_type in ['blob', 'mix']:
            self.total_blob_masks = get_blob_masks(blob_masks_path)
        else:
            self.total_blob_masks = None

        self.mask_type = mask_type
        self.batch_size = batch_size
        self.num_workers = num_workers
        # synthetic root dataset
        self.dataset = SyntheticRootDataset(which_set=which_set,
                                            dilation=dilation,
                                            noisy_texture=noisy_texture,
                                            rotation=rotation)
        self.n_samples = len(self.dataset)
        self.image_size = image_size

        # shuffle only while training
        self.shuffle = self.dataset.training

        super(SyntheticRootDataLoader, self).__init__(
            dataset=self.dataset,
            batch_size=1,  # batch_size fixed to 1: a single full image is loaded per iteration and many patches are extracted from it
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=partial(full_seg_collate_fn,
                               batch_size=self.batch_size,
                               mask_type=mask_type,
                               image_size=image_size,
                               total_blob_masks=self.total_blob_masks,
                               training=self.dataset.training))
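
A minimal usage sketch for this loader follows. The import path is an assumption for illustration (not a documented part of the project); the constructor arguments mirror the defaults shown above.

# Hypothetical usage sketch -- the import path below is an assumption.
from dataloaders import SyntheticRootDataLoader

train_loader = SyntheticRootDataLoader(which_set='train',
                                       batch_size=16,
                                       mask_type='mix',
                                       dilation=True,
                                       noisy_texture=True,
                                       rotation=True,
                                       image_size=256,
                                       num_workers=0)

# The underlying DataLoader batch size is 1: each iteration loads one full
# synthetic root image, and full_seg_collate_fn extracts batch_size patches from it.
for batch in train_loader:
    # the exact batch structure is determined by full_seg_collate_fn
    break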
Example #2
    def __init__(self,
                 which_set='train',
                 batch_size=16,
                 mask_type='mix',
                 image_size=256,
                 num_workers=0):
        """
        Initialization of line data loader
        :param which_set: 'train'|'valid'|'test'
        :param batch_size: number of patches to be extracted from a single segmentation mask
        :param mask_type: 'square'|'brush'|'blob'|'mix'
        :param image_size: patch size
        :param num_workers: number of workers, normally set to 0
        """

        assert mask_type in ['square', 'blob', 'brush', 'mix']
        if mask_type in ['blob', 'mix']:
            self.total_blob_masks = get_blob_masks(blob_masks_path)
        else:
            self.total_blob_masks = None

        self.mask_type = mask_type
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = LineDataset(which_set=which_set)
        self.n_samples = len(self.dataset)
        self.image_size = image_size

        # shuffle only while training
        self.shuffle = self.dataset.training

        super(LineDataLoader, self).__init__(
            dataset=self.dataset,
            batch_size=1,  # batch_size fixed to 1: a single full image is loaded per iteration and many patches are extracted from it
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=partial(full_seg_collate_fn,
                               batch_size=self.batch_size,
                               mask_type=mask_type,
                               image_size=image_size,
                               total_blob_masks=self.total_blob_masks,
                               training=self.dataset.training))
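
A similar sketch for the line data loader, here building a validation loader. The import path is again hypothetical, and shuffling is assumed to be off whenever the underlying dataset reports training=False.

# Hypothetical usage sketch -- the import path is an assumption.
from dataloaders import LineDataLoader

valid_loader = LineDataLoader(which_set='valid',
                              batch_size=16,
                              mask_type='brush',
                              image_size=256,
                              num_workers=0)

# shuffle is enabled only when the underlying dataset is in training mode
print(valid_loader.shuffle, valid_loader.n_samples)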
Example #3
    def __init__(self,
                 which_set='train',
                 batch_size=16,
                 mask_type='mix',
                 num_workers=0):
        """
        Initialization of chickpea patch data loader
        :param which_set: 'train'|'test'|'valid'
        :param batch_size: number of patches to be extracted from a single segmentation mask
        :param mask_type: 'square'|'brush'|'blob'|'mix'
        :param num_workers: number of workers, normally set to 0
        """

        assert mask_type in ['square', 'blob', 'brush', 'mix']
        if mask_type in ['blob', 'mix']:
            self.total_blob_masks = get_blob_masks(blob_masks_path)
        else:
            self.total_blob_masks = None

        self.mask_type = mask_type
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = ChickpeaPatchRootDataset(which_set=which_set)
        self.n_samples = len(self.dataset)
        # shuffle only while training
        self.shuffle = self.dataset.training

        super(ChickpeaPatchRootDataLoader, self).__init__(
            dataset=self.dataset,
            batch_size=batch_size,  # each dataset item is already a patch, so batch_size is passed to the DataLoader directly
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=partial(chickpea_patch_collate_fn,
                               mask_type=mask_type,
                               total_blob_masks=self.total_blob_masks,
                               training=self.dataset.training))
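
Finally, a sketch for the chickpea patch loader. Unlike the two loaders above, the dataset already yields individual patches, so batch_size goes straight to the underlying DataLoader. The import path is once more hypothetical.

# Hypothetical usage sketch -- the import path is an assumption.
from dataloaders import ChickpeaPatchRootDataLoader

test_loader = ChickpeaPatchRootDataLoader(which_set='test',
                                          batch_size=16,
                                          mask_type='square',
                                          num_workers=0)

# Each batch groups 16 pre-extracted patches; gap masks are generated
# per patch by chickpea_patch_collate_fn.
for batch in test_loader:
    break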