Example #1
0
def get_data(name, meta_dir, gpu_nums):
    """Build the Cityscapes dataflow for one split.

    Args:
        name: split name. Any name containing the substring ``'train'``
            selects the training pipeline; everything else selects the
            evaluation pipeline.
        meta_dir: forwarded to ``CityscapesFiles`` together with the
            module-level ``base_dir``.
        gpu_nums: number of GPUs; a training batch holds
            ``args.batch_size * gpu_nums`` datapoints.

    Returns:
        A ``BatchData`` dataflow.

        Training: shuffled file list, decoded by ``MultiThreadMapData`` with
        ``min(3, cpu_count())`` workers. ``RandomCropWithPadding(args.crop_size)``
        and ``Flip(horiz=True)`` are applied with the same random params to
        image and label. The per-channel mean is then subtracted from the
        image, so the image becomes float64.

        Evaluation: unshuffled ``[image, label]`` pairs with no augmentation
        and no mean subtraction, batch size 1.

    Raises:
        IOError: lazily, when a datapoint is produced, if OpenCV cannot read
            an image or label file.
    """
    is_train = 'train' in name

    def read_pair(dp):
        # dp is an (image_path, label_path) pair emitted by CityscapesFiles.
        img_path, label_path = dp
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None instead of raising on a missing or
        # undecodable file. Fail here, with the path, rather than deep inside
        # augmentation or batching.
        if img is None:
            raise IOError('cannot read image: {}'.format(img_path))
        if label is None:
            raise IOError('cannot read label: {}'.format(label_path))
        return img, label

    if not is_train:
        def imgread(dp):
            img, label = read_pair(dp)
            return [img, label]

        ds = CityscapesFiles(base_dir, meta_dir, name, shuffle=False)
        ds = MapData(ds, imgread)
        return BatchData(ds, 1)

    ds = CityscapesFiles(base_dir, meta_dir, name, shuffle=True)
    parallel = min(3, multiprocessing.cpu_count())

    # Subtracted channel-wise from the image, whose channels are in OpenCV's
    # BGR order (IMREAD_COLOR). A (3,) vector broadcasts over the trailing
    # channel axis of an HWC image of any size (the layout is HWC, not NCHW),
    # so no crop-sized constant array is needed. float64 makes the
    # mean-subtracted image float64.
    mean_bgr = np.array([104, 116, 122], dtype=np.float64)

    aug = imgaug.AugmentorList([
        RandomCropWithPadding(args.crop_size),
        Flip(horiz=True),
    ])

    def mapf(dp):
        img, label = read_pair(dp)
        img, params = aug.augment_return_params(img)
        # Replay the params drawn for the image on the label so both receive
        # the same crop and flip.
        label = aug._augment(label, params)
        return img - mean_bgr, label

    ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=500, strict=True)
    return BatchData(ds, args.batch_size * gpu_nums)
Example #2
0
# Turn off MXNet GPU peer-to-peer communication via its environment switch.
os.environ['MXNET_ENABLE_GPU_P2P'] = '0'


# Label value used for pixels that should be ignored.
IGNORE_LABEL = 255

# Crop and tile geometry (all square, 673 px).
CROP_HEIGHT = CROP_WIDTH = 673
tile_height = tile_width = 673

# Per-GPU batch size (previously 7).
batch_size = 5

# Training schedule.
EPOCH_SCALE = 9
end_epoch = 10
init_lr = 2.5e-4
# Presumably (epoch, learning-rate) breakpoints -- confirm against the scheduler.
lr_step_list = [
    (2, 2.5e-4),
    (4, 1e-4),
    (6, 1e-5),
    (10, 8e-6),
]
NUM_CLASSES = CityscapesFiles.class_num()
validation_on_last = 2

kvstore = "device"
# Parameter-name prefixes; presumably kept fixed during training -- confirm.
fixed_param_prefix = [
    'conv0_weight',
    'stage1',
    'beta',
    'gamma',
]


def parse_args():
    """Register the command-line options for deeplab training.

    NOTE(review): as visible here, the body only adds arguments. It never
    calls ``parser.parse_args()`` and has no ``return``, so a caller would
    receive None. The excerpt appears truncated after ``--scratch``; confirm
    against the full source.
    """
    parser = argparse.ArgumentParser(description='Train deeplab network')
    # GPU selection passed as a string; default "1". Multi-GPU format not shown here.
    parser.add_argument("--gpu", default="1")
    parser.add_argument('--frequent', help='frequency of logging', default=50000, type=int)
    # The following store_true flags default to False.
    parser.add_argument('--view', action='store_true')
    parser.add_argument("--validation", action="store_true")
    parser.add_argument("--test_speed", action="store_true")
    # Presumably a pretrained-weights / checkpoint identifier -- confirm usage.
    parser.add_argument("--load", default="tornadomeet-resnet-101-0000")
    parser.add_argument("--scratch", action="store_true" )
Example #3
0
def get_data(name, meta_dir, gpu_nums):
    """Build the Cityscapes dataflow for MXNet training or evaluation.

    Pipeline: read files -> decode (4 workers) -> [train: RandomResize,
    RandomCrop(args.crop_size) + horizontal Flip on image and label]
    -> subtract per-channel mean (4 workers) -> batch.

    Args:
        name: split name. Exactly ``'train'`` selects the training pipeline;
            anything else selects the evaluation pipeline.
        meta_dir: forwarded to ``CityscapesFiles``.
        gpu_nums: number of GPUs each training batch is split across.

    Returns:
        Training: a dataflow yielding ``(dl, ll)``. Each is a list with one
        entry per GPU, and each entry is a one-element list holding an
        ``mx.nd.array`` of ``args.batch_size`` samples. Images are NCHW and
        labels have shape ``(N, 1, H, W)``.
        Evaluation: a ``BatchData`` of batch size 1 over mean-subtracted
        ``(image, label)`` pairs.

    Raises:
        IOError: lazily, when a datapoint is produced, if OpenCV cannot read
            an image or label file.
    """
    is_train = name == 'train'

    # Subtracted channel-wise from each image (channels are in OpenCV's BGR
    # order). Shape (1, 1, 3) broadcasts over the trailing channel axis of an
    # HWC image; it is not NCHW. Built once instead of once per sample.
    mean_pixel = np.array([104, 116, 122]).reshape(1, 1, 3)

    def imgread(dp):
        img_path, label_path = dp
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None instead of raising on a missing or
        # undecodable file. Surface that here with the offending path.
        if img is None:
            raise IOError('cannot read image: {}'.format(img_path))
        if label is None:
            raise IOError('cannot read label: {}'.format(label_path))
        return img, label

    def reduce_mean_rgb(dp):
        image, label = dp
        return image - mean_pixel, label

    def mxnet_prepare(dp):
        data, label = dp
        data = np.transpose(data, (0, 3, 1, 2))  # NHWC -> NCHW
        label = label[:, None, :, :]  # (N, H, W) -> (N, 1, H, W)
        per_gpu = args.batch_size
        # Multi-GPU distribution: one slice of per_gpu samples per device.
        # Time-consuming, since each slice is copied into a new mx.nd.array.
        dl = [[mx.nd.array(data[per_gpu * i:per_gpu * (i + 1)])]
              for i in range(gpu_nums)]
        ll = [[mx.nd.array(label[per_gpu * i:per_gpu * (i + 1)])]
              for i in range(gpu_nums)]
        return dl, ll

    ds = CityscapesFiles(meta_dir, name, shuffle=is_train)
    ds = MultiThreadMapData(ds, 4, imgread)
    if is_train:
        ds = MapData(ds, RandomResize)
        # Geometric augmentation applied to both components (image, label).
        shape_aug = [
            RandomCrop(args.crop_size),
            Flip(horiz=True),
        ]
        ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=False)

    ds = MultiThreadMapData(ds, 4, reduce_mean_rgb)

    if not is_train:
        return BatchData(ds, 1)

    ds = FastBatchData(ds, args.batch_size * gpu_nums)
    return MapData(ds, mxnet_prepare)