コード例 #1
0
def get_data(name, data_dir, meta_dir, gpu_nums):
    """Build the batched dataflow for the Aerial dataset split *name*.

    Args:
        name: split name passed to ``Aerial``. ``'train'`` enables
            ``RandomResize``, random crop-with-padding + horizontal flip
            augmentation, multi-GPU batching and ZMQ prefetching.
        data_dir: not used by this function (kept for interface compatibility).
        meta_dir: passed to ``Aerial`` as its first argument.
        gpu_nums: number of GPUs; the training batch size is
            ``args.batch_size * gpu_nums``.

    Returns:
        The final dataflow. Training: batches of ``args.batch_size * gpu_nums``
        wrapped in ``PrefetchDataZMQ(ds, 1)``. Otherwise: batches of 1.

    Note:
        Reads the module-level ``args`` (``crop_size``, ``batch_size``) and
        ``IGNORE_LABEL``; both must be defined before this is called.
    """
    is_train = name == 'train'
    # NOTE(review): shuffling is enabled for every split, including
    # evaluation; confirm that is intended.
    ds = Aerial(meta_dir, name, shuffle=True)

    if is_train:
        ds = MapData(ds, RandomResize)
        shape_aug = [
            # IGNORE_LABEL is handed to the cropper, presumably as the fill
            # value for padded label pixels -- confirm.
            RandomCropWithPadding(args.crop_size, IGNORE_LABEL),
            Flip(horiz=True),
        ]
    else:
        shape_aug = []

    # (0, 1) selects datapoint components 0 and 1 -- the (image, label) pair
    # unpacked in subtract_mean below.
    ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=False)

    # Per-channel mean, shaped (1, 1, 3) so it broadcasts along the image's
    # LAST axis: this assumes channels-last (H, W, 3) images, not NCHW.
    # The values look like the usual Caffe BGR means -- confirm the channel
    # order produced by Aerial. Built once here instead of per datapoint.
    mean = np.array([104, 116, 122]).reshape(1, 1, 3)

    def subtract_mean(dp):
        """Return ``(image - mean, label)`` for one datapoint."""
        image, label = dp
        return image - mean, label

    ds = MapData(ds, subtract_mean)

    if is_train:
        ds = BatchData(ds, args.batch_size * gpu_nums)
        ds = PrefetchDataZMQ(ds, 1)
    else:
        ds = BatchData(ds, 1)
    return ds
コード例 #2
0
# Set MXNet's GPU peer-to-peer flag to '0' (presumably disabling P2P copies
# between devices). Kept first so it is set before the rest of this section runs.
os.environ['MXNET_ENABLE_GPU_P2P'] = '0'


# Sentinel label value used by this script.
IGNORE_LABEL = 255

# Crop and tile sizes (presumably pixels; crop for training, tile for
# inference -- confirm).
CROP_HEIGHT = CROP_WIDTH = 473
tile_height = tile_width = 513

# Training schedule.
batch_size = 11
EPOCH_SCALE = 4
end_epoch = 9
init_lr = 1e-4
# Presumably (epoch, learning_rate) pairs.
# NOTE(review): the first entry repeats init_lr, so it looks like a no-op
# step; confirm whether a decay was intended there.
lr_step_list = [(6, 1e-4), (9, 1e-5)]
NUM_CLASSES = Aerial.class_num()
validation_on_last = end_epoch

kvstore = "device"
# Parameter-name prefixes, presumably kept fixed (not updated) during
# training -- confirm against the trainer.
fixed_param_prefix = ['conv0_weight', 'stage1', 'beta', 'gamma']
from symbol_resnet_deeplabv2 import resnet101_deeplab_new

def parse_args():
    parser = argparse.ArgumentParser(description='Train deeplab network')
    # training
    parser.add_argument("--gpu", default="3")
    parser.add_argument('--frequent', help='frequency of logging', default=1000, type=int)
    parser.add_argument('--view', action='store_true')
    parser.add_argument("--validation", action="store_true")
    parser.add_argument("--load", default="../tornadomeet-resnet-101-0000")
    parser.add_argument("--scratch", action="store_true" )