Пример #1
0
def get_data(name):
    """Build the BSDS500 dataflow for one split.

    Components 0 and 1 of every datapoint receive the same geometric
    augmentation. Component 1 (presumably the ground-truth boundary map —
    confirm against ``dataset.BSDS500``) is then binarized at 0.5. In
    training mode, photometric augmentation is also applied through
    ``AugmentImageComponent``'s default component (presumably 0, the image).

    Args:
        name: split name forwarded to ``dataset.BSDS500``. ``'train'``
            enables random resize/rotate/crop/flip and photometric
            augmentation. Any other value uses a fixed 320x480 center crop.

    Returns:
        For training, a dataflow batched 4 at a time with
        ``BatchDataByShape(idx=0)`` and wrapped in ``PrefetchDataZMQ``.
        Otherwise, a dataflow batched 1 at a time.
    """
    is_train = name == 'train'
    ds = dataset.BSDS500(name, shuffle=True)

    class CropMultiple16(imgaug.ImageAugmentor):
        """Randomly crop an image so both sides become multiples of 16."""

        def _get_augment_params(self, img):
            newh = img.shape[0] // 16 * 16
            neww = img.shape[1] // 16 * 16
            # Raise instead of assert so the check survives `python -O`.
            if newh <= 0 or neww <= 0:
                raise ValueError(
                    "CropMultiple16 needs an image of at least 16x16, "
                    "got shape {}".format(img.shape[:2]))
            # Valid top/left offsets are 0..diff inclusive. randint's upper
            # bound is exclusive (assuming self.rng is a numpy RandomState),
            # so pass diff + 1. Otherwise a crop flush with the bottom/right
            # edge could never be sampled. When diff == 0 this returns 0.
            h0 = self.rng.randint(img.shape[0] - newh + 1)
            w0 = self.rng.randint(img.shape[1] - neww + 1)
            return (h0, w0, newh, neww)

        def _augment(self, img, param):
            h0, w0, newh, neww = param
            return img[h0:h0 + newh, w0:w0 + neww]

    if is_train:
        shape_aug = [
            imgaug.RandomResize(xrange=(0.7, 1.5),
                                yrange=(0.7, 1.5),
                                aspect_ratio_thres=0.15),
            imgaug.RotationAndCropValid(90),
            CropMultiple16(),
            imgaug.Flip(horiz=True),
            imgaug.Flip(vert=True)
        ]
    else:
        # the original image shape (321x481) in BSDS is not a multiple of 16
        IMAGE_SHAPE = (320, 480)
        shape_aug = [imgaug.CenterCrop(IMAGE_SHAPE)]
    # Apply the same geometric transform to components 0 and 1 so they stay
    # aligned.
    ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=False)

    def threshold(m):
        """Binarize ``m`` in place at 0.5 and return it."""
        m[m >= 0.50] = 1
        m[m < 0.50] = 0
        return m

    ds = MapDataComponent(ds, threshold, 1)

    if is_train:
        augmentors = [
            imgaug.Brightness(63, clip=False),
            imgaug.Contrast((0.4, 1.5)),
        ]
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        # Random resizing yields varying shapes, so batch by component-0
        # shape.
        ds = BatchDataByShape(ds, 4, idx=0)
        ds = PrefetchDataZMQ(ds, 1)
    else:
        ds = BatchData(ds, 1)
    return ds
Пример #2
0
def get_data():
    """Build a batched training dataflow from the shuffled BSDS500 train split.

    Pipeline: grayscale + resize to (SHAPE, SHAPE) -> ThetaImages ->
    repeat 50x -> PrefetchDataZMQ with min(20, CPU count) -> BatchData
    of size BATCH.

    Returns:
        The final batched dataflow.
    """
    # probably not the best dataset
    flow = dataset.BSDS500('train', shuffle=True)
    preprocessors = [
        imgaug.Grayscale(keepdims=False),
        imgaug.Resize((SHAPE, SHAPE)),
    ]
    flow = AugmentImageComponent(flow, preprocessors)
    flow = ThetaImages(flow)
    # just pretend this dataset is bigger
    flow = RepeatedData(flow, 50)
    # this pre-computation is pretty heavy, so spread it out, capped at 20
    parallelism = min(20, multiprocessing.cpu_count())
    flow = PrefetchDataZMQ(flow, parallelism)
    return BatchData(flow, BATCH)