def __init__(self, which_set='train', batch_size=16, mask_type='mix', dilation=True,
             noisy_texture=True, rotation=True, image_size=256, num_workers=0):
    """
    Set up a data loader over the synthetic root dataset.

    The parent loader is asked for one dataset item at a time; the collate
    function (``full_seg_collate_fn``) receives ``batch_size``, ``mask_type``,
    ``image_size``, the blob masks and the training flag as bound arguments.

    :param which_set: dataset split, 'train', 'valid' or 'test'
    :param batch_size: how many small patches to extract from a single segmentation mask
    :param mask_type: gap type, one of 'square' | 'blob' | 'brush' | 'mix'
    :param dilation: root dilation flag, forwarded to the dataset
    :param noisy_texture: noisy texture flag, forwarded to the dataset
    :param rotation: root rotation flag, forwarded to the dataset
    :param image_size: patch size
    :param num_workers: number of workers, normally left at 0
    """
    assert mask_type in ('square', 'blob', 'brush', 'mix')

    # Blob masks are only loaded when mask_type is 'blob' or 'mix'.
    wants_blobs = mask_type in ('blob', 'mix')
    self.total_blob_masks = get_blob_masks(blob_masks_path) if wants_blobs else None

    self.mask_type = mask_type
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.image_size = image_size

    # Synthetic root dataset backing this loader.
    root_dataset = SyntheticRootDataset(
        which_set=which_set,
        dilation=dilation,
        noisy_texture=noisy_texture,
        rotation=rotation,
    )
    self.dataset = root_dataset
    self.n_samples = len(root_dataset)

    # Shuffle exactly when the dataset reports that it is in training mode.
    is_training = root_dataset.training
    self.shuffle = bool(is_training)

    patch_collate = partial(
        full_seg_collate_fn,
        batch_size=batch_size,
        mask_type=mask_type,
        image_size=image_size,
        total_blob_masks=self.total_blob_masks,
        training=is_training,
    )

    # The parent gets batch_size=1 because a single full image is used to
    # extract many patches; the patch count is carried by the collate function.
    # NOTE(review): the parent class is not visible here; if it is torch's
    # DataLoader, its __init__ reassigns self.batch_size to the value passed
    # below (1) - confirm before relying on self.batch_size afterwards.
    super(SyntheticRootDataLoader, self).__init__(
        dataset=root_dataset,
        batch_size=1,
        shuffle=self.shuffle,
        num_workers=num_workers,
        collate_fn=patch_collate,
    )
def __init__(self, which_set='train', batch_size=16, mask_type='mix', image_size=256, num_workers=0):
    """
    Set up a data loader over the line dataset.

    One dataset item is drawn per step by the parent loader; the bound
    ``full_seg_collate_fn`` is given the patch count, mask type, patch size,
    blob masks and training flag.

    :param which_set: dataset split, 'train' | 'valid' | 'test'
    :param batch_size: number of patches
    :param mask_type: 'square' | 'brush' | 'blob' | 'mix'
    :param image_size: patch size
    :param num_workers: number of workers, normally left at 0
    """
    assert mask_type in ('square', 'blob', 'brush', 'mix')

    # Only the 'blob' and 'mix' mask types trigger loading of the blob masks.
    blob_masks = None
    if mask_type in ('blob', 'mix'):
        blob_masks = get_blob_masks(blob_masks_path)
    self.total_blob_masks = blob_masks

    self.mask_type = mask_type
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.image_size = image_size

    line_dataset = LineDataset(which_set=which_set)
    self.dataset = line_dataset
    self.n_samples = len(line_dataset)

    # Shuffling follows the dataset's own training flag.
    training_flag = line_dataset.training
    self.shuffle = bool(training_flag)

    collate = partial(
        full_seg_collate_fn,
        batch_size=batch_size,
        mask_type=mask_type,
        image_size=image_size,
        total_blob_masks=blob_masks,
        training=training_flag,
    )

    # batch_size=1 for the parent: one full image per step, from which the
    # collate function extracts many patches.
    # NOTE(review): the parent class is not visible here; if it is torch's
    # DataLoader, its __init__ reassigns self.batch_size to 1 - confirm
    # before reading self.batch_size after construction.
    super(LineDataLoader, self).__init__(
        dataset=line_dataset,
        batch_size=1,
        shuffle=self.shuffle,
        num_workers=num_workers,
        collate_fn=collate,
    )
def __init__(self, which_set='train', batch_size=16, mask_type='mix', num_workers=0):
    """
    Set up a data loader over the chickpea patch root dataset.

    ``batch_size`` is passed straight through to the parent loader, and
    ``chickpea_patch_collate_fn`` is bound with the mask type, the blob
    masks and the dataset's training flag.

    :param which_set: dataset split, 'train' | 'test' | 'valid'
    :param batch_size: number of dataset items per batch (forwarded to the parent loader)
    :param mask_type: 'square' | 'brush' | 'blob' | 'mix'
    :param num_workers: number of workers, default 0
    """
    assert mask_type in ('square', 'blob', 'brush', 'mix')

    # Blob masks are loaded only for the 'blob' and 'mix' mask types.
    uses_blobs = mask_type in ('blob', 'mix')
    self.total_blob_masks = get_blob_masks(blob_masks_path) if uses_blobs else None

    self.mask_type = mask_type
    self.batch_size = batch_size
    self.num_workers = num_workers

    patch_dataset = ChickpeaPatchRootDataset(which_set=which_set)
    self.dataset = patch_dataset
    self.n_samples = len(patch_dataset)

    # Shuffle exactly when the dataset reports training mode.
    in_training = patch_dataset.training
    self.shuffle = bool(in_training)

    collate = partial(
        chickpea_patch_collate_fn,
        mask_type=mask_type,
        total_blob_masks=self.total_blob_masks,
        training=in_training,
    )

    # The caller's batch_size is handed to the parent loader unchanged
    # (it is not forced to 1 in this loader).
    super(ChickpeaPatchRootDataLoader, self).__init__(
        dataset=patch_dataset,
        batch_size=batch_size,
        shuffle=self.shuffle,
        num_workers=num_workers,
        collate_fn=collate,
    )