示例#1
0
def get_data(file_name):
    if file_name.endswith('.lmdb'):
        ds = LMDBDataPoint(file_name, shuffle=True)
        ds = ImageDecode(ds, index=0)
    elif file_name.endswith('.zip'):
        ds = ImageDataFromZIPFile(file_name, shuffle=True)
        ds = ImageDecode(ds, index=0)
        ds = RejectTooSmallImages(ds, index=0)
        ds = CenterSquareResize(ds, index=0)
    else:
        raise ValueError("Unknown file format " + file_name)
    augmentors = [imgaug.RandomCrop(128),
                  imgaug.Flip(horiz=True)]
    ds = AugmentImageComponent(ds, augmentors, index=0, copy=True)
    ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_CUBIC), x[0]])
    ds = PrefetchDataZMQ(ds, 3)
    ds = BatchData(ds, BATCH_SIZE)
    return ds
def get_data(file_name, train_or_test):
    isTrain = train_or_test == 'train'

    if file_name.endswith('.lmdb'):
        ds = LMDBSerializer.load(file_name, shuffle=True)
        if config.USE_YCBCR is True:
            ds = ImageDecodeYCrCb(ds, index=0)
        else:
            ds = ImageDecodeBGR(ds, index=0)
    elif file_name.endswith('.zip'):
        ds = ImageDataFromZIPFile(file_name, shuffle=True)
        if config.USE_YCBCR is True:
            ds = ImageDecodeYCrCb(ds, index=0)
        else:
            ds = ImageDecodeBGR(ds, index=0)
        ds = RejectTooSmallImages(ds, thresh=config.INPUT_IMAGE_SIZE, index=0)
        # ds = CenterSquareResize(ds, index=0)
    else:
        raise ValueError("Unknown file format " + file_name)

    if isTrain:
        augmentors = [
            imgaug.RandomCrop(100),
            # imgaug.RandomApplyAug(imgaug.RandomChooseAug([
            #     imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
            #     imgaug.RandomOrderAug([
            #         imgaug.BrightnessScale((0.98, 1.02), clip=True),
            #         # imgaug.Contrast((0.98, 1.02), rgb=None, clip=True),
            #         # imgaug.Saturation(0.4, rgb=False),  # only for RGB or BGR images!
            #         ]),
            #     ]), 0.7),
            # imgaug.SaltPepperNoise(white_prob=0.01, black_prob=0.01),
            imgaug.RandomApplyAug(
                imgaug.RandomOrderAug([
                    imgaug.Flip(horiz=True),
                    imgaug.Flip(vert=True),
                    imgaug.Rotation(180, (0, 1), cv2.INTER_CUBIC, step_deg=90)
                    # imgaug.BrightnessScale((0.98, 1.02), clip=True),
                    # imgaug.Contrast((0.98, 1.02), rgb=None, clip=True),
                    # imgaug.Saturation(0.4, rgb=False),  # only for RGB or BGR images!
                ]),
                0.7),

            # imgaug.MinMaxNormalize(0.0001, config.NORMALIZE, all_channel=True),
            # MinMaxNormalize(min=0, max=config.NORMALIZE, all_channel=False),
        ]
    else:
        augmentors = [
            imgaug.RandomCrop(100),
            # imgaug.MinMaxNormalize(min=0, max=config.NORMALIZE, all_channel=False),
        ]

    ds = AugmentImageComponent(ds, augmentors, index=0, copy=True)

    # if isTrain:
    #     ds = PrefetchData(ds, 2, 2)

    scaled_size = config.INPUT_IMAGE_SIZE / config.SCALE

    # ds = MapData(ds, lambda x: [np.expand_dims(cv2.resize(x[0], (scaled_size, scaled_size), interpolation=cv2.INTER_CUBIC), axis=3),
    #                             np.expand_dims(x[0], axis=3),
    #                             np.expand_dims(cv2.resize(cv2.resize(x[0], (scaled_size, scaled_size), interpolation=cv2.INTER_CUBIC), (config.INPUT_IMAGE_SIZE, config.INPUT_IMAGE_SIZE), interpolation=cv2.INTER_CUBIC),axis=3),
    #                             ])

    # ds = MapData(ds, lambda x: [np.reshape(cv2.resize(x[0], None, fx=1. / config.SCALE, fy=1. / config.SCALE, interpolation=cv2.INTER_CUBIC), (cv2.resize(x[0], None, fx=1. / config.SCALE, fy=1. / config.SCALE, interpolation=cv2.INTER_CUBIC).shape[0], cv2.resize(x[0], None, fx=1. / config.SCALE, fy=1. / config.SCALE, interpolation=cv2.INTER_CUBIC).shape[1], 1)),
    #                             np.expand_dims(x[0], axis=3),
    #                             np.reshape(cv2.resize(cv2.resize(x[0], None, fx=1. / config.SCALE, fy=1. / config.SCALE, interpolation=cv2.INTER_CUBIC), None, fx=1. * config.SCALE, fy=1. * config.SCALE, interpolation=cv2.INTER_CUBIC), (x[0].shape[0], x[0].shape[1], 1))])

    ds = MapData(
        ds, lambda x: [
            np.reshape(
                cv2.resize(x[0],
                           None,
                           fx=1. / config.SCALE,
                           fy=1. / config.SCALE,
                           interpolation=cv2.INTER_CUBIC),
                (50, 50, config.CHANNELS)),
            np.reshape(x[0], (config.INPUT_IMAGE_SIZE, config.INPUT_IMAGE_SIZE,
                              config.CHANNELS)),
            np.reshape(
                cv2.resize(cv2.resize(x[0],
                                      None,
                                      fx=1. / config.SCALE,
                                      fy=1. / config.SCALE,
                                      interpolation=cv2.INTER_CUBIC),
                           None,
                           fx=1. * config.SCALE,
                           fy=1. * config.SCALE,
                           interpolation=cv2.INTER_CUBIC),
                (config.INPUT_IMAGE_SIZE, config.INPUT_IMAGE_SIZE, config.
                 CHANNELS))
        ])
    # print(ds)
    # quit()
    if isTrain:
        ds = MultiProcessRunnerZMQ(ds, config.DATAFLOW_PROC)
        ds = BatchData(ds, config.BATCH_SIZE, remainder=not isTrain)

    return ds