Exemplo n.º 1
0
def get_data(name, data_dir, meta_dir, gpu_nums):
    """Build the CamVid input dataflow for the split called ``name``.

    Relies on the module-level ``args`` (``crop_size``, ``batch_size``) and
    ``IGNORE_LABEL`` being defined by the surrounding file.

    Args:
        name: Split name, forwarded to ``Camvid``. Training mode is enabled
            whenever it contains the substring ``'train'``.
        data_dir: Forwarded to ``Camvid``.
        meta_dir: Forwarded to ``Camvid``.
        gpu_nums: Multiplier applied to ``args.batch_size`` to obtain the
            training batch size.

    Returns:
        The batched dataflow. In training mode it is batched with
        ``args.batch_size * gpu_nums`` and wrapped in ``PrefetchDataZMQ``;
        otherwise it is batched with size 1.
    """
    is_train = 'train' in name
    # NOTE(review): shuffle=True is passed for every split, including
    # non-training ones -- confirm this is intended.
    ds = Camvid(data_dir, meta_dir, name, shuffle=True)

    if is_train:
        # Training only: random resize first, then crop/pad and flip.
        ds = MapData(ds, RandomResize)
        shape_aug = [
            RandomCropWithPadding(args.crop_size, IGNORE_LABEL),
            Flip(horiz=True),
        ]
    else:
        shape_aug = []

    # Apply the same augmentors to components 0 and 1 (image and label).
    ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=False)

    # Built once instead of per datapoint. Shape (1, 1, 3) broadcasts against
    # a trailing axis of length 3, i.e. a channels-last image is assumed.
    # Presumably these are per-channel means -- confirm the channel order.
    channel_offset = np.array([104, 116, 122]).reshape(1, 1, 3)

    def subtract_offset(datapoint):
        """Subtract the constant channel offsets from the image component."""
        image, label = datapoint
        return image - channel_offset, label

    ds = MapData(ds, subtract_offset)
    if is_train:
        ds = BatchData(ds, args.batch_size * gpu_nums)
        ds = PrefetchDataZMQ(ds, 1)
    else:
        ds = BatchData(ds, 1)
    return ds
Exemplo n.º 2
0
def get_data(name, data_dir, meta_dir, gpu_nums):
    """Build the PascalVOC12 input dataflow for the split called ``name``.

    Relies on the module-level ``args`` (``crop_size``, ``batch_size``) and
    ``IGNORE_LABEL`` being defined by the surrounding file.

    Args:
        name: Split name, forwarded to ``PascalVOC12``. Training mode is
            enabled whenever it contains the substring ``'train'``.
        data_dir: Forwarded to ``PascalVOC12``.
        meta_dir: Forwarded to ``PascalVOC12``.
        gpu_nums: Multiplier applied to ``args.batch_size`` to obtain the
            training batch size.

    Returns:
        The batched dataflow. In training mode it is batched with
        ``args.batch_size * gpu_nums`` and wrapped in ``PrefetchDataZMQ``;
        otherwise it is batched with size 1.
    """
    is_train = 'train' in name
    # NOTE(review): shuffle=True is passed for every split, including
    # non-training ones -- confirm this is intended.
    ds = PascalVOC12(data_dir, meta_dir, name, shuffle=True)

    if is_train:
        shape_aug = [
            # Nearest-neighbour interpolation is used for every augmented
            # component (image and label alike).
            RandomResize(xrange=(0.7, 1.5),
                         yrange=(0.7, 1.5),
                         aspect_ratio_thres=0.15,
                         interp=cv2.INTER_NEAREST),
            RandomCropWithPadding(args.crop_size, IGNORE_LABEL),
            Flip(horiz=True),
        ]
    else:
        shape_aug = []

    # Apply the same augmentors to components 0 and 1 (image and label).
    ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=False)

    # Built once instead of per datapoint. Shape (1, 1, 3) broadcasts against
    # a trailing axis of length 3, i.e. a channels-last image is assumed.
    # Presumably these are per-channel means -- confirm the channel order.
    channel_offset = np.array([104, 116, 122]).reshape(1, 1, 3)

    def subtract_offset(datapoint):
        """Subtract the constant channel offsets from the image component."""
        image, label = datapoint
        return image - channel_offset, label

    ds = MapData(ds, subtract_offset)
    if is_train:
        ds = BatchData(ds, args.batch_size * gpu_nums)
        ds = PrefetchDataZMQ(ds, 1)
    else:
        ds = BatchData(ds, 1)
    return ds