def get_data(name):
    """Build the BSDS500 dataflow for the split called ``name``.

    When ``name`` is exactly ``'train'`` the pipeline applies random
    geometric augmentation to components 0 and 1, binarizes component 1,
    applies brightness/contrast augmentation, batches with
    ``BatchDataByShape`` and adds a ``PrefetchDataZMQ`` stage. For any other
    split the samples are center-cropped to 320x480, component 1 is
    binarized, and the result is batched one sample at a time.

    Args:
        name: split name forwarded to ``dataset.BSDS500``.

    Returns:
        The fully wrapped dataflow.
    """
    is_train = (name == 'train')

    def binarize(m):
        """Threshold ``m`` in place at 0.5 (>= 0.5 -> 1, < 0.5 -> 0)."""
        m[m >= 0.50] = 1
        m[m < 0.50] = 0
        return m

    ds = dataset.BSDS500(name, shuffle=True)

    class CropMultiple16(imgaug.ImageAugmentor):
        """Randomly crop an image so both spatial sides are multiples of 16."""

        def _get_augment_params(self, img):
            """Return ``(h0, w0, newh, neww)`` describing the crop window."""
            height, width = img.shape[0], img.shape[1]
            extra_h = height % 16
            extra_w = width % 16
            crop_h = height - extra_h
            crop_w = width - extra_w
            assert crop_h > 0 and crop_w > 0
            # Draw the row offset before the column offset so the sequence
            # of RNG calls is fixed.
            # NOTE(review): self.rng.randint(k) presumably samples from
            # [0, k) -- confirm against the augmentor base class.
            if extra_h == 0:
                top = 0
            else:
                top = self.rng.randint(extra_h)
            if extra_w == 0:
                left = 0
            else:
                left = self.rng.randint(extra_w)
            return (top, left, crop_h, crop_w)

        def _augment(self, img, param):
            """Slice the window described by ``param`` out of ``img``."""
            top, left, crop_h, crop_w = param
            return img[top:top + crop_h, left:left + crop_w]

    if is_train:
        print("IN TRAIN")
        resize = imgaug.RandomResize(
            xrange=(0.7, 1.5),
            yrange=(0.7, 1.5),
            aspect_ratio_thres=0.15,
        )
        shape_aug = [
            resize,
            imgaug.RotationAndCropValid(90),
            CropMultiple16(),
            imgaug.Flip(horiz=True),
            imgaug.Flip(vert=True),
        ]
    else:
        print("NOT IN TRAIN")
        # the original image shape (321x481) in BSDS is not a multiple of 16
        eval_shape = (320, 480)
        shape_aug = [imgaug.CenterCrop(eval_shape)]

    # The same geometric augmentation is applied to components 0 and 1.
    ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=False)
    # Component 1 is binarized after the geometric augmentation.
    ds = MapDataComponent(ds, binarize, 1)

    if not is_train:
        return BatchData(ds, 1)

    color_aug = [
        imgaug.Brightness(63, clip=False),
        imgaug.Contrast((0.4, 1.5)),
    ]
    ds = AugmentImageComponent(ds, color_aug, copy=False)
    # Random resizing yields varying shapes, hence the shape-aware batching.
    ds = BatchDataByShape(ds, 4, idx=0)
    ds = PrefetchDataZMQ(ds, 1)
    return ds
def get_data():
    """Build the batched dataflow derived from the BSDS500 ``'train'`` split.

    Each image goes through ``imgaug.Grayscale(keepdims=False)`` and is
    resized to ``(SHAPE, SHAPE)``. The stream is then wrapped in
    ``ThetaImages``, repeated 50 times, prefetched through
    ``PrefetchDataZMQ`` and batched with ``BATCH``.

    Returns:
        The fully wrapped dataflow.
    """
    # probably not the best dataset
    ds = dataset.BSDS500('train', shuffle=True)

    preprocess = [
        imgaug.Grayscale(keepdims=False),
        imgaug.Resize((SHAPE, SHAPE)),
    ]
    ds = AugmentImageComponent(ds, preprocess)
    ds = ThetaImages(ds)

    # just pretend this dataset is bigger
    ds = RepeatedData(ds, 50)

    # this pre-computation is pretty heavy, so the second argument is
    # min(CPU count, 20)
    nr_proc = min(multiprocessing.cpu_count(), 20)
    ds = PrefetchDataZMQ(ds, nr_proc)

    return BatchData(ds, BATCH)