示例#1
0
def main():
    """Fine-tune a two-class FCN-xs network on the images listed in ``train.lst``.

    ``args.model`` picks the variant: ``"fcn16s"`` and ``"fcn8s"`` select those
    networks, any other value keeps the FCN-32s settings.  Each variant has its
    own checkpoint prefix, epoch count and learning rate.

    Starting weights come from the checkpoint ``args.prefix`` / ``args.epoch``
    and, unless ``args.retrain`` is set, are passed through the initialiser
    named by ``args.init_type`` (``"vgg16"`` or ``"fcnxs"``; any other value
    leaves the loaded weights untouched).

    ``args`` and ``ctx`` are not defined here and must be provided by the
    enclosing module.  No validation data is passed to ``fit``.
    """
    # FCN-32s settings double as the fallback for any other ``args.model``.
    fcnxs = symbol_fcnxs.get_fcn32s_symbol(numclass=2, workspace_default=1536)
    fcnxs_model_prefix = "model_pascal/FCN32s_VGG16"
    num_epoch = 1
    learning_rate = 1e-3
    if args.model == "fcn16s":
        fcnxs = symbol_fcnxs.get_fcn16s_symbol(numclass=2, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN16s_VGG16"
        learning_rate = 1e-5
    elif args.model == "fcn8s":
        fcnxs = symbol_fcnxs.get_fcn8s_symbol(numclass=2, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN8s_VGG16"
        num_epoch = 30
        learning_rate = 1e-7

    # Only the parameter dicts are used; the first value returned by
    # load_checkpoint is ignored.
    _, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(args.prefix, args.epoch)
    if not args.retrain:
        if args.init_type == "vgg16":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_vgg16(
                ctx, fcnxs, fcnxs_args, fcnxs_auxs)
        elif args.init_type == "fcnxs":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_fcnxs(
                ctx, fcnxs, fcnxs_args, fcnxs_auxs)

    # ``cut_off_size`` is deliberately not passed to the iterator.
    train_dataiter = FileIter(
        root_dir=".",
        flist_name="train.lst",
        rgb_mean=(123.68, 116.779, 103.939),
    )

    model = Solver(
        ctx=ctx,
        symbol=fcnxs,
        begin_epoch=0,
        num_epoch=num_epoch,
        arg_params=fcnxs_args,
        aux_params=fcnxs_auxs,
        learning_rate=learning_rate,
        momentum=0.9,
        wd=0.0001)

    # Validation is disabled in this variant: ``eval_data`` is not supplied.
    model.fit(
        train_data=train_dataiter,
        batch_end_callback=mx.callback.Speedometer(1, 10),
        epoch_end_callback=mx.callback.do_checkpoint(fcnxs_model_prefix))
示例#2
0
def main():
    """Fine-tune a 21-class FCN-xs network on the ``./VOC2012`` file lists.

    ``args.model`` picks the variant: ``"fcn16s"`` and ``"fcn8s"`` select those
    networks, any other value keeps the FCN-32s symbol and checkpoint prefix.

    Starting weights come from the checkpoint ``args.prefix`` / ``args.epoch``
    and, unless ``args.retrain`` is set, are passed through the initialiser
    named by ``args.init_type`` (``"vgg16"`` or ``"fcnxs"``; any other value
    leaves the loaded weights untouched).  Training uses ``train.lst`` and
    evaluates on ``val.lst``.

    ``args`` is not defined here and must be provided by the enclosing module.
    """
    # NOTE(review): this is a truthiness test, so if ``args.gpu`` is an int
    # then GPU id 0 also selects the CPU -- confirm against the argument parser.
    ctx = mx.gpu(args.gpu) if args.gpu else mx.cpu()

    # FCN-32s settings double as the fallback for any other ``args.model``.
    fcnxs = symbol_fcnxs.get_fcn32s_symbol(numclass=21, workspace_default=1536)
    fcnxs_model_prefix = "model_pascal/FCN32s_VGG16"
    if args.model == "fcn16s":
        fcnxs = symbol_fcnxs.get_fcn16s_symbol(numclass=21, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN16s_VGG16"
    elif args.model == "fcn8s":
        fcnxs = symbol_fcnxs.get_fcn8s_symbol(numclass=21, workspace_default=1536)
        fcnxs_model_prefix = "model_pascal/FCN8s_VGG16"

    # Only the parameter dicts are used; the first value returned by
    # load_checkpoint is ignored.
    _, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(args.prefix, args.epoch)
    if not args.retrain:
        if args.init_type == "vgg16":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_vgg16(
                ctx, fcnxs, fcnxs_args, fcnxs_auxs)
        elif args.init_type == "fcnxs":
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_fcnxs(
                ctx, fcnxs, fcnxs_args, fcnxs_auxs)

    # ``cut_off_size`` is deliberately not passed to either iterator.
    train_dataiter = FileIter(
        root_dir="./VOC2012",
        flist_name="train.lst",
        rgb_mean=(123.68, 116.779, 103.939),
    )
    val_dataiter = FileIter(
        root_dir="./VOC2012",
        flist_name="val.lst",
        rgb_mean=(123.68, 116.779, 103.939),
    )

    model = Solver(
        ctx=ctx,
        symbol=fcnxs,
        begin_epoch=0,
        num_epoch=50,
        arg_params=fcnxs_args,
        aux_params=fcnxs_auxs,
        learning_rate=1e-10,
        momentum=0.99,
        wd=0.0005)
    model.fit(
        train_data=train_dataiter,
        eval_data=val_dataiter,
        batch_end_callback=mx.callback.Speedometer(1, 10),
        epoch_end_callback=mx.callback.do_checkpoint(fcnxs_model_prefix))
示例#3
0
def main():
    """Train one stage of a two-class FCN-xs cascade, scored with ``DiceMetric``.

    ``args.model`` selects the stage: ``"fcn16s"``, ``"fcn8s"``, or anything
    else for FCN-32s.  Every stage starts from the checkpoint prefix and epoch
    configured for the stage before it (FCN-32s starts from
    ``VGG_FC_ILSVRC_16_layers``) and saves under ``args.model_dir``.

    ``args.gpu`` is a comma-separated list of GPU ids; blank entries are
    skipped and only the first id is used.  ``args.cutoff == 0`` means no
    cut-off size is passed to the data iterators.

    ``args`` is not defined here and must be provided by the enclosing module.

    Raises:
        AssertionError: if ``args.gpu`` contains no GPU id.
        ValueError: if an entry of ``args.gpu`` is not an integer.
    """
    gpu_list = [
        int(token)
        for token in (part.strip() for part in args.gpu.split(','))
        if token
    ]
    if not gpu_list:
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; the exception type is the one the assert produced.
        raise AssertionError(
            "args.gpu must contain at least one GPU id, got %r" % (args.gpu,))
    # Only the first listed device is used for training.
    ctx = mx.gpu(gpu_list[0])

    carvn_root = ''
    num_classes = 2
    cutoff = None if args.cutoff == 0 else args.cutoff

    # Index 0 is the VGG-16 starting point; indices 1..3 are the FCN-32s,
    # FCN-16s and FCN-8s stages.  Stage ``i`` loads entry ``i - 1`` and
    # trains/saves as entry ``i``.
    epochs = [74, 31, 27, 19]
    model_prefixes = [
        'VGG_FC_ILSVRC_16_layers',
        args.model_dir + "/FCN32s_VGG16",
        args.model_dir + "/FCN16s_VGG16",
        args.model_dir + "/FCN8s_VGG16",
    ]
    if args.model == "fcn16s":
        stage, lr = 2, 1e-5
        build_symbol = symbol_fcnxs.get_fcn16s_symbol
    elif args.model == "fcn8s":
        stage, lr = 3, 1e-6
        build_symbol = symbol_fcnxs.get_fcn8s_symbol
    else:
        stage, lr = 1, 1e-4
        build_symbol = symbol_fcnxs.get_fcn32s_symbol
    fcnxs = build_symbol(numclass=num_classes, workspace_default=1536)
    fcnxs_model_prefix = model_prefixes[stage]
    load_prefix = model_prefixes[stage - 1]
    run_epochs = epochs[stage]
    load_epoch = epochs[stage - 1]

    print('loading', load_prefix, load_epoch)
    print('lr', lr)
    print('model_prefix', fcnxs_model_prefix)
    print('running epochs', run_epochs)

    # Only the parameter dicts are used; the first value returned by
    # load_checkpoint is ignored.
    _, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(load_prefix, load_epoch)
    if not args.retrain:
        # The 16s/8s stages start from an FCN checkpoint, FCN-32s from VGG-16.
        if args.model in ("fcn16s", "fcn8s"):
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_fcnxs(
                ctx, fcnxs, fcnxs_args, fcnxs_auxs)
        else:
            fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_vgg16(
                ctx, fcnxs, fcnxs_args, fcnxs_auxs)

    train_dataiter = FileIter(
        root_dir=carvn_root,
        flist_name="../data/train.lst",
        cut_off_size=cutoff,
        rgb_mean=(123.68, 116.779, 103.939),
    )
    val_dataiter = FileIter(
        root_dir=carvn_root,
        flist_name="../data/val.lst",
        cut_off_size=cutoff,
        rgb_mean=(123.68, 116.779, 103.939),
    )

    model = Solver(
        ctx=ctx,
        symbol=fcnxs,
        begin_epoch=0,
        num_epoch=run_epochs,
        arg_params=fcnxs_args,
        aux_params=fcnxs_auxs,
        learning_rate=lr,
        momentum=0.99,
        wd=0.0005)
    _metric = DiceMetric()
    model.fit(
        train_data=train_dataiter,
        eval_data=val_dataiter,
        eval_metric=_metric,
        batch_end_callback=mx.callback.Speedometer(1, 10),
        epoch_end_callback=mx.callback.do_checkpoint(fcnxs_model_prefix))
示例#4
0
def main():
    """Continue training a two-class FCN-style network from an earlier checkpoint.

    The function derives the old and new checkpoint prefixes, builds the
    symbol named by ``args.model``, loads the pre-trained checkpoint,
    re-initialises its parameters via ``args.init_type``, creates the
    train/validation iterators with ``get_train_val_iter`` and runs ``Solver``
    up to ``args.epoch``, checkpointing with ``period=5``.

    Several names used here (``args``, ``ctx``, ``resize_size``, ``data_type``,
    ``data_root_dir``, ``use_record_data``, ``buffer_image_set``,
    ``data_train``, ``data_val``, ``period``, ``to_eval_train``,
    ``eval_metric``, ``speedometer_freq``) are not defined in this function
    and must be provided by the enclosing module.

    Raises:
        NotImplementedError: if ``args.model`` or ``args.init_type`` is not
            one of the values handled below.
    """
    # region 0. Prepare the models
    # Previous model: the checkpoint used as the pre-trained starting point.
    prev_root = './model_coco_person'
    prev_batch, prev_lr = 8, 1e-6
    pre_epoch = 50
    pre_kind = 'FCN32s'
    pre_prefix = "%s/%s_VGG16_size%d_batch%d_lr%.0e" % (
        prev_root, pre_kind, resize_size[0], prev_batch, prev_lr)

    # Prefix for the new model.
    # Candidate roots: model_coco_person, model_pascal_person
    out_root = './model_coco_person'
    batch_size, lr = 8, 1e-7
    out_prefix = "%s/%s_VGG16_data%s_size%d_batch%d_lr%.0e" % (
        out_root, args.model, data_type, resize_size[0], batch_size, lr)

    # Solver's begin_epoch is the epoch of the pre-trained checkpoint.
    first_epoch = pre_epoch

    for template, value in (
            ('model prefix: %s', out_prefix),
            ('new_learning_rate: %.0e', lr),
            ('batch_size: %d', batch_size),
            ('resize_size: %s', str(resize_size)),
    ):
        logging.info(template % value)
    # endregion

    # region 1. Build the model
    # The builder attribute is looked up only on the branch that is taken, so
    # a builder missing from the imported modules matters only when selected.
    model_kind = args.model
    workspace = 2048 if model_kind == 'FCN32s' else 1536
    if model_kind == 'FCN32s':
        make_symbol = symbol_fcnxs.get_fcn32s_symbol
    elif model_kind == "FCN16s":
        make_symbol = symbol_fcnxs.get_fcn16s_symbol
    elif model_kind == "FCN8s":
        make_symbol = symbol_fcnxs.get_fcn8s_symbol
    elif model_kind == "FCN4s":
        make_symbol = symbol_fcnxs.get_fcn4s_symbol
    elif model_kind == "FCN_atrous":
        make_symbol = symbol_fcnxs_atrous_person.get_fcnatrous_symbol
    else:
        raise NotImplementedError
    fcnxs = make_symbol(numclass=2, workspace_default=workspace)
    # endregion

    # region 2. Load the pre-trained checkpoint and initialise the FCN model
    logging.info('pre train with %s---000%d' % (pre_prefix, pre_epoch))
    _, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(pre_prefix, pre_epoch)

    # Pick the initialiser only after the checkpoint has been loaded.
    init_kind = args.init_type
    if init_kind == "vgg16":
        initialise = init_fcnxs.init_from_vgg16
    elif init_kind == "fcnxs":
        initialise = init_fcnxs.init_from_fcnxs
    else:
        raise NotImplementedError
    fcnxs_args, fcnxs_auxs = initialise(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
    # endregion

    # region Prepare training and validation data
    mean_rgb = (123.68, 116.779, 103.939)
    record_args = {
        'data_train': data_train,
        'data_val': data_val,
        'image_shape': (3, resize_size[0], resize_size[1]),
        'rgb_mean': mean_rgb,
        'batch_size': batch_size,
        'data_nthreads': 50
    }
    # ``cut_off_size`` is deliberately not passed.
    train_dataiter, val_dataiter = get_train_val_iter(
        use_record_data=use_record_data,
        root_dir=data_root_dir,
        resize_size=resize_size,
        batch_size=batch_size,
        rgb_mean=mean_rgb,
        buffer_image_set=buffer_image_set,
        args=record_args)
    # endregion

    # region Start training
    solver_settings = dict(
        ctx=ctx,
        symbol=fcnxs,
        begin_epoch=first_epoch,
        num_epoch=args.epoch,
        arg_params=fcnxs_args,
        aux_params=fcnxs_auxs,
        learning_rate=lr,
        momentum=0.99,
        wd=0.0005,
    )
    model = Solver(**solver_settings)
    model.fit(
        train_data=train_dataiter,
        eval_data=val_dataiter,
        period=period,
        to_eval_train=to_eval_train,
        eval_metric=eval_metric,
        batch_end_callback=mx.callback.Speedometer(
            batch_size=batch_size,
            frequent=speedometer_freq,
            auto_reset=False),
        epoch_end_callback=mx.callback.do_checkpoint(out_prefix, period=5))
    # endregion