Example No. 1
 def __init__(self, symbol, ctx=None,
              begin_epoch=0, num_epoch=None,
              arg_params=None, aux_params=None,
              valid_metric=None,
              class_names=None,
              optimizer='sgd', **kwargs):
     self.symbol = symbol
     if ctx is None:
         ctx = mx.cpu(0)
     self.ctx = ctx
     self.begin_epoch = begin_epoch
     self.num_epoch = num_epoch
     self.arg_params = arg_params
     self.aux_params = aux_params
     # instantiate mutable defaults here so they are not shared across instances
     self.valid_metric = valid_metric if valid_metric is not None else MApMetric()
     self.class_names = class_names if class_names is not None else []
     self.optimizer = optimizer
     self.evaluation_only = False
     self.kwargs = kwargs.copy()
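This constructor fragment presumably belongs to a solver class such as the MultiTaskSolver used in Example 2; a hypothetical instantiation (all names and values are illustrative, and extra keywords like learning_rate land in **kwargs):

import mxnet as mx

solver = MultiTaskSolver(symbol=net,          # a previously built training symbol
                         ctx=mx.gpu(0),
                         begin_epoch=0,
                         num_epoch=50,
                         class_names=['car', 'pedestrian'],
                         optimizer='sgd',
                         learning_rate=1e-3)  # captured by **kwargs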
Example No. 2
def train_multitask(netname,
                    train_path,
                    num_classes,
                    batch_size,
                    data_shape,
                    mean_pixels,
                    resume,
                    finetune,
                    pretrained,
                    epoch,
                    prefix,
                    ctx,
                    begin_epoch,
                    end_epoch,
                    frequent,
                    learning_rate,
                    momentum,
                    weight_decay,
                    lr_refactor_step,
                    lr_refactor_ratio,
                    freeze_layer_pattern='',
                    num_example=10000,
                    label_pad_width=350,
                    nms_thresh=0.45,
                    force_nms=False,
                    ovp_thresh=0.5,
                    use_difficult=False,
                    class_names=None,
                    voc07_metric=False,
                    nms_topk=400,
                    force_suppress=False,
                    train_list="",
                    val_path="",
                    val_list="",
                    iter_monitor=0,
                    monitor_pattern=".*",
                    log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    netname : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in the record file
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in the record file
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger; the root logger defaults to WARNING, so set its level
    # explicitly or the DEBUG-level file handler would never receive records
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    if not os.path.exists('log'):
        os.makedirs('log')
    fh = logging.FileHandler(
        os.path.join(
            'log',
            time.strftime('%F-%T', time.localtime()).replace(':', '-') +
            '.log'))
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    logger.addHandler(fh)
    logger.addHandler(ch)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + netname + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    logger.info(
        str({
            "train_path": train_path,
            "batch_size": batch_size,
            "data_shape": data_shape
        }))
    train_iter = MultiTaskRecordIter(train_path,
                                     batch_size,
                                     data_shape,
                                     mean_pixels=mean_pixels,
                                     label_pad_width=label_pad_width,
                                     path_imglist=train_list,
                                     enable_aug=True,
                                     **cfg.train)

    if val_path:
        val_iter = MultiTaskRecordIter(val_path,
                                       batch_size,
                                       data_shape,
                                       mean_pixels=mean_pixels,
                                       label_pad_width=label_pad_width,
                                       path_imglist=val_list,
                                       enable_aug=False,
                                       **cfg.valid)
    else:
        val_iter = None

    # load symbol
    logger.info(
        str({
            "num_classes": num_classes,
            "nms_thresh": nms_thresh,
            "force_suppress": force_suppress,
            "nms_topk": nms_topk
        }))
    if netname in ["resnet-18", "resnet-50"]:
        net = get_fcn32s_symbol_train(netname,
                                      data_shape[1],
                                      num_classes=num_classes,
                                      nms_thresh=nms_thresh,
                                      force_suppress=force_suppress,
                                      nms_topk=nms_topk)
    elif netname.endswith("det"):
        net = get_det_symbol_train(netname.split("_")[0],
                                   data_shape[1],
                                   num_classes=num_classes,
                                   nms_thresh=nms_thresh,
                                   force_suppress=force_suppress,
                                   nms_topk=nms_topk)
    elif netname.endswith("seg"):
        net = get_seg_symbol_train(netname.split("_")[0],
                                   data_shape[1],
                                   num_classes=num_classes,
                                   nms_thresh=nms_thresh,
                                   force_suppress=force_suppress,
                                   nms_topk=nms_topk)
    elif netname.endswith("multi"):
        net = get_multi_symbol_train(netname.split("_")[0],
                                     data_shape[1],
                                     num_classes=num_classes,
                                     nms_thresh=nms_thresh,
                                     force_suppress=force_suppress,
                                     nms_topk=nms_topk)
    else:
        raise NotImplementedError("unsupported network name: {}".format(netname))

    ################# analyze shapes #######################
    # arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(1,3,512,1024), label_det=(1,200,6))
    # arg_names = net.list_arguments()
    # print([(n,s) for n,s in zip(arg_names,arg_shapes)])

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None
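    # For example, a hypothetical freeze_layer_pattern of '^(conv1_|conv2_)'
    # would pin every argument whose name starts with conv1_ or conv2_;
    # the pattern is matched against names from net.list_arguments(),
    # e.g. something like 'conv1_1_weight'.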

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    ctx = ctx[0]  # the solvers below run on a single context
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
        args = {key: val.as_in_context(ctx) for key, val in args.items()}
        auxs = {key: val.as_in_context(ctx) for key, val in auxs.items()}
    # elif finetune > 0:
    #     logger.info("Start finetuning with {} from epoch {}".format(ctx_str, finetune))
    #     _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
    #     begin_epoch = finetune
    #     # the prediction convolution layers name starts with relu, so it's fine
    #     fixed_param_names = [name for name in net.list_arguments() if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        # args = convert_pretrained(pretrained, args)
        args = {key: val.as_in_context(ctx) for key, val in args.items()}
        auxs = {key: val.as_in_context(ctx) for key, val in auxs.items()}
        args, auxs = init_from_resnet(ctx, net, args, auxs)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args, auxs, fixed_param_names = None, None, None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # init training module
    logger.info("Creating Module ...")
    mod = mx.mod.Module(net,
                        label_names=(
                            'label_det',
                            'seg_out_label',
                        ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        'rescale_grad': 1.0  # single context after ctx = ctx[0], so no 1/len(ctx) rescale
    }
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None
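    # For reference, get_lr_scheduler presumably converts the epoch-based steps
    # in lr_refactor_step (e.g. '30,60,90') into iteration counts; a minimal,
    # hypothetical sketch of that conversion using MXNet's stock scheduler:
    #
    #   iters_per_epoch = num_example // batch_size
    #   steps = [(int(s) - begin_epoch) * iters_per_epoch
    #            for s in lr_refactor_step.split(',')
    #            if int(s) > begin_epoch]
    #   lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(
    #       step=steps, factor=float(lr_refactor_ratio))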

    # run fit net, every n epochs we run evaluation network to get mAP
    valid_metric = MApMetric(ovp_thresh,
                             use_difficult,
                             class_names,
                             pred_idx=0)

    from pprint import pprint
    import numpy as np
    import cv2  # used only by the commented-out visualization code below
    from palette import color2index, index2color  # used only by the visualization code below

    pprint(optimizer_params)
    np.set_printoptions(formatter={"float": lambda x: "%.3f " % x},
                        suppress=True)

    ############### uncomment the following lines to visualize network ###########################
    # dot = mx.viz.plot_network(net, shape={'data':(1,3,512,1024),"label_det":(1,200,6)})
    # dot.view()

    ############### uncomment the following lines to visualize data ###########################
    # data_batch, _ = train_iter.next()
    # pprint({"data":data_batch.data[0].shape,
    #         "label_det":data_batch.label[0].shape,
    #         "seg_out_label":data_batch.label[1].shape})
    # data = data_batch.data[0].asnumpy()
    # label = data_batch.label[0].asnumpy()
    # segmt = data_batch.label[1].asnumpy()
    # for ii in range(data.shape[0]):
    #     img = data[ii,:,:,:]
    #     seg = segmt[ii,:,:]
    #     print(label[ii,:5,:])
    #     img = np.squeeze(img)
    #     img = np.swapaxes(img, 0, 2)
    #     img = np.swapaxes(img, 0, 1)
    #     img = (img + np.array([123.68, 116.779, 103.939]).reshape((1,1,3))).astype(np.uint8)
    #     img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    #     rois = label[ii,:,:]
    #     hh, ww, ch = img.shape
    #     for lidx in range(rois.shape[0]):
    #         roi = rois[lidx,:]
    #         if roi[0]>=0:
    #             cv2.rectangle(img, (int(roi[1]*ww),int(roi[2]*hh)), (int(roi[3]*ww),int(roi[4]*hh)), (0,0,128))
    #             cls_id = int(roi[0])
    #             bbox = [int(roi[1]*ww),int(roi[2]*hh),int(roi[3]*ww),int(roi[4]*hh)]
    #             text = '%s %.0fm' % (class_names[cls_id], roi[5]*255.)
    #             putText(img,bbox,text)
    #     disp = np.zeros((hh*2, ww, ch),np.uint8)
    #     disp[:hh,:, :] = img.astype(np.uint8)
    #     disp[hh:,:, :] = cv2.resize(index2color(seg),(ww,hh),interpolation=cv2.INTER_NEAREST)
    #     cv2.imshow("img", disp)
    #     if cv2.waitKey()&0xff==27: exit(0)


    pprint({"ctx":ctx,"begin_epoch":begin_epoch,"end_epoch":end_epoch, \
           "learning_rate":learning_rate,"momentum":momentum})

    model = None
    if netname.endswith("multi"):
        model = MultiTaskSolver(
            ctx=ctx,
            symbol=net,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            arg_params=args,
            aux_params=auxs,
            learning_rate=learning_rate,
            lr_scheduler=lr_scheduler,
            momentum=momentum,
            wd=weight_decay,  # honor the weight_decay argument (was hard-coded 0.0005)
            valid_metric=valid_metric,
            class_names=class_names,
        )
    elif netname.endswith("det"):
        model = DetTaskSolver(
            ctx=ctx,
            symbol=net,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            arg_params=args,
            aux_params=auxs,
            learning_rate=learning_rate,
            lr_scheduler=lr_scheduler,
            momentum=momentum,
            wd=weight_decay,  # honor the weight_decay argument (was hard-coded 0.0005)
            valid_metric=valid_metric,
            class_names=class_names,
        )
    elif netname.endswith("seg"):
        model = SegTaskSolver(
            ctx=ctx,
            symbol=net,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            arg_params=args,
            aux_params=auxs,
            learning_rate=learning_rate,
            lr_scheduler=lr_scheduler,
            momentum=momentum,
            wd=weight_decay,  # honor the weight_decay argument (was hard-coded 0.0005)
            valid_metric=valid_metric,
            class_names=class_names,
        )
    else:
        raise NotImplementedError("unsupported network name: {}".format(netname))

    model.fit(train_data=train_iter,
              eval_data=val_iter,
              batch_end_callback=batch_end_callback,
              epoch_end_callback=epoch_end_callback)
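A hypothetical end-to-end call of train_multitask; the record files, checkpoint prefixes, class list, and hyper-parameters below are illustrative assumptions, not values from the source:

import mxnet as mx

train_multitask(netname='resnet-18_multi',
                train_path='data/train.rec',
                num_classes=4,
                batch_size=8,
                data_shape=512,                    # expanded to (3, 512, 512)
                mean_pixels=(123.68, 116.779, 103.939),
                resume=0,
                finetune=0,
                pretrained='model/resnet-18',      # loads resnet-18-0000.params
                epoch=0,
                prefix='model/multitask',
                ctx=[mx.gpu(0)],
                begin_epoch=0,
                end_epoch=50,
                frequent=20,
                learning_rate=1e-3,
                momentum=0.9,
                weight_decay=5e-4,
                lr_refactor_step='30,40',
                lr_refactor_ratio=0.1,
                class_names=['car', 'pedestrian', 'cyclist', 'truck'])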
Example No. 3
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in the record file
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in the record file
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
        label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
            label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
        nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '('+ ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # check what layers mismatch with the loaded parameters
        exe = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300), label=(1, 1, 5), grad_req='null')
        arg_dict = exe.arg_dict
        fixed_param_names = []
        for k, v in arg_dict.items():
            if k in args:
                if v.shape != args[k].shape:
                    del args[k]
                    logger.info("Removed %s" % k)
                else:
                    if 'pred' not in k:
                        fixed_param_names.append(k)
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
            .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
        lr_refactor_ratio, num_example, batch_size, begin_epoch)
    optimizer_params={'learning_rate':learning_rate,
                      'momentum':momentum,
                      'wd':weight_decay,
                      'lr_scheduler':lr_scheduler,
                      'clip_gradient':None,
                      'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0 }
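    # Assuming each context computes the mean gradient over its shard of the
    # batch and the per-device gradients are then summed, rescale_grad=1/len(ctx)
    # restores the batch-mean update that single-device training would apply.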
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
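Resuming differs from pretraining only in the flags; a hypothetical call that continues training from epoch 10 (the checkpoint prefix and data paths are assumptions):

import mxnet as mx

train_net('vgg16_reduced', 'data/train.rec', num_classes=20, batch_size=32,
          data_shape=300, mean_pixels=(123.68, 116.779, 103.939),
          resume=10,           # load <prefix>-0010.params and restart there
          finetune=0, pretrained='', epoch=0, prefix='model/ssd',
          ctx=[mx.gpu(0)], begin_epoch=0, end_epoch=240, frequent=20,
          learning_rate=0.002, momentum=0.9, weight_decay=5e-4,
          lr_refactor_step='80,160', lr_refactor_ratio=0.1)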
Example No. 4
def evaluate_net(net,
                 path_imgrec,
                 num_classes,
                 mean_pixels,
                 data_shape,
                 model_prefix,
                 epoch,
                 ctx=mx.cpu(),
                 batch_size=1,
                 path_imglist="",
                 nms_thresh=0.45,
                 force_nms=False,
                 ovp_thresh=0.5,
                 use_difficult=False,
                 class_names=None,
                 voc07_metric=False):
    """
    Evaluate network given a validation record file.

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    path_imglist : str
        path to the list file to replace labels in record file, optional
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    model_prefix += '_' + str(data_shape[1])

    # iterator
    eval_iter = DetRecordIter(path_imgrec,
                              batch_size,
                              data_shape,
                              mean_pixels=mean_pixels,
                              path_imglist=path_imglist,
                              **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if net is None:
        net = load_net
    else:
        net = get_symbol(net,
                         data_shape[1],
                         num_classes=num_classes,
                         nms_thresh=nms_thresh,
                         force_suppress=force_nms)
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    # init module
    mod = mx.mod.Module(net,
                        label_names=('label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # run evaluation
    if voc07_metric:
        metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names)
    else:
        metric = MApMetric(ovp_thresh, use_difficult, class_names)
    results = mod.score(eval_iter, metric, num_batch=None)
    for k, v in results:
        print("{}: {}".format(k, v))
Example No. 5
def evaluate_net(net,
                 path_imgrec,
                 num_classes,
                 mean_pixels,
                 data_shape,
                 model_prefix,
                 epoch,
                 ctx=mx.cpu(),
                 batch_size=1,
                 path_imglist="",
                 nms_thresh=0.45,
                 force_nms=False,
                 ovp_thresh=0.5,
                 use_difficult=False,
                 class_names=None,
                 voc07_metric=False,
                 frequent=20):
    """
    Evaluate network given a validation record file.

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    path_imglist : str
        path to the list file to replace labels in record file, optional
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    frequent : int
        frequency to print out validation status
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    #model_prefix += '_' + str(data_shape[1])

    # iterator
    eval_iter = DetRecordIter(path_imgrec,
                              batch_size,
                              data_shape,
                              path_imglist=path_imglist,
                              **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if net is None:
        net = load_net
    else:
        net = get_symbol(net,
                         data_shape[1],
                         num_classes=num_classes,
                         nms_thresh=nms_thresh,
                         force_suppress=force_nms)

    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    plot_shape = (1, 3, 300, 300)  # for plotting only; do not clobber data_shape
    mx.viz.plot_network_detail(net,
                               shape={
                                   "data": plot_shape,
                                   "label": (1, 1, 5)
                               },
                               node_attrs={
                                   "hide_weights": "true",
                                   "fixedsize": 'false',
                                   "shape": 'oval'
                               }).view()

    # Gang Chen add
    exe = net.simple_bind(mx.cpu(),
                          data=(1, 3, 300, 300),
                          label=(1, 1, 5),
                          grad_req='null')
    arg_dict = exe.arg_dict

    # iterate over a snapshot of the keys: deleting from a dict while iterating
    # over it raises RuntimeError in Python 3
    for k in list(args.keys()):
        if k not in arg_dict:
            del args[k]

    # END Gang Chen add

    # init module
    mod = mx.mod.Module(net,
                        label_names=('label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=True, force_init=True)

    # run evaluation
    if voc07_metric:
        metric = VOC07MApMetric(ovp_thresh,
                                use_difficult,
                                class_names,
                                roc_output_path=os.path.join(
                                    os.path.dirname(model_prefix), 'roc'))
    else:
        metric = MApMetric(ovp_thresh,
                           use_difficult,
                           class_names,
                           roc_output_path=os.path.join(
                               os.path.dirname(model_prefix), 'roc'))

    results = mod.score(eval_iter,
                        metric,
                        num_batch=None,
                        batch_end_callback=mx.callback.Speedometer(
                            batch_size, frequent=frequent, auto_reset=False))

    for k, v in results:
        print("{}: {}".format(k, v))
Example No. 6
def train_net_common(net, train_iter, val_iter, batch_size,
                     data_shape, resume, finetune, pretrained, epoch,
                     prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
                     momentum, weight_decay, use_plateau, lr_refactor_step, lr_refactor_ratio,
                     freeze_layer_pattern='',
                     num_example=10000, label_pad_width=350,
                     nms_thresh=0.45, force_suppress=False, ovp_thresh=0.5,
                     use_difficult=False, class_names=None,
                     optimizer_name='sgd',
                     voc07_metric=False, nms_topk=400,
                     iter_monitor=0,
                     monitor_pattern=".*", logger=None):
    """
    """
    # check args
    prefix += '_' + net + '_' + str(data_shape[1])

    # load symbol
    net_str = net
    net = get_symbol_train(net, data_shape[1], num_classes=len(class_names),
        nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '('+ ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        try:
            logger.info("Start training with {} from pretrained model {}"
                .format(ctx_str, pretrained))
            _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
            args = convert_pretrained(pretrained, args)
            if net_str == 'ssd_pva':
                args, auxs = convert_pvanet(args, auxs)
        except Exception:
            logger.info("Failed to load the pretrained model. Start from scratch.")
            args = None
            auxs = None
            fixed_param_names = None
    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    if not use_plateau: # focal loss does not go well with plateau
        mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                fixed_param_names=fixed_param_names)
    else:
        mod = PlateauModule(net, label_names=('label',), logger=logger, context=ctx,
                fixed_param_names=fixed_param_names)

    # robust parameter setting
    mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
    mod = set_mod_params(mod, args, auxs, logger)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent, auto_reset=True)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None
    optimizer_params={'learning_rate': learning_rate,
                      'wd': weight_decay,
                      'clip_gradient': 4.0,
                      'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0 }
    if optimizer_name == 'sgd':
        optimizer_params['momentum'] = momentum

    # #7847
    mod.init_optimizer(optimizer=optimizer_name, optimizer_params=optimizer_params, force_init=True)

    if not use_plateau:
        learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
                lr_refactor_ratio, num_example, batch_size, begin_epoch)
    else:
        w_l1 = cfg.train['smoothl1_weight']
        eval_weights = {'CrossEntropy': 1.0, 'SmoothL1': w_l1}
        plateau_lr = PlateauScheduler( \
                patient_epochs=lr_refactor_step, factor=float(lr_refactor_ratio), eval_weights=eval_weights)
        plateau_metric = MultiBoxMetric(fn_stat='/home/hyunjoon/github/additions_mxnet/ssd/stat.txt')

    eval_metric = MultiBoxMetric()
    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=4)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=4)

    if not use_plateau:
        mod.fit(train_iter,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
    else:
        mod.fit(train_iter,
                plateau_lr, plateau_metric=plateau_metric,
                fn_curr_model=prefix+'-1000.params',
                plateau_backtrace=False,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                validation_period=5,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
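train_net_common expects the caller to supply the iterators and a logger (a logger of None would crash on the first logger.info call); a hypothetical wiring, with paths, shapes, and class names as illustrative assumptions:

import logging
import mxnet as mx

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
train_iter = DetRecordIter('data/train.rec', 32, (3, 300, 300), **cfg.train)
val_iter = DetRecordIter('data/val.rec', 32, (3, 300, 300), **cfg.valid)
train_net_common('ssd_pva', train_iter, val_iter, batch_size=32,
                 data_shape=(3, 300, 300), resume=0, finetune=0,
                 pretrained='model/pvanet', epoch=0, prefix='model/ssd',
                 ctx=[mx.gpu(0)], begin_epoch=0, end_epoch=240, frequent=20,
                 learning_rate=1e-3, momentum=0.9, weight_decay=5e-4,
                 use_plateau=False, lr_refactor_step='80,160',
                 lr_refactor_ratio=0.1, class_names=['car', 'pedestrian'],
                 logger=logger)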
Example No. 7
def train_net(net,
              train_path,
              num_classes,
              batch_size,
              data_shape,
              mean_img,
              mean_img_dir,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              lr_refactor_step,
              lr_refactor_ratio,
              convert_numpy=1,
              freeze_layer_pattern='',
              num_example=10000,
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              class_names=None,
              voc07_metric=False,
              nms_topk=400,
              force_suppress=False,
              train_list="",
              val_path="",
              val_list="",
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None,
              summarywriter=0,
              flush_secs=180):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_img : str
        path to the mean image file subtracted from inputs
    mean_img_dir : str
        directory holding the mean image binary (converted to .npy if convert_numpy)
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in the record file
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in the record file
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])

    train_iter = DetRecordIter(train_path,
                               batch_size,
                               data_shape,
                               mean_img=mean_img,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list,
                               **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path,
                                 batch_size,
                                 data_shape,
                                 mean_img=mean_img,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list,
                                 **cfg.valid)
    else:
        val_iter = None

    # convert mean.bin to mean.npy
    _convert_mean_numpy(convert_numpy, mean_img_dir, mean_img)
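    # _convert_mean_numpy is project-specific; assuming mean.bin is an NDArray
    # file written by MXNet's record tools, a hypothetical equivalent would be:
    #   mean_nd = mx.nd.load(os.path.join(mean_img_dir, 'mean.bin'))['mean_img']
    #   np.save(os.path.join(mean_img_dir, 'mean.npy'), mean_nd.asnumpy())
    # (the 'mean_img' key is an assumption, not confirmed by the source)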

    # load symbol
    net = get_symbol_train(net,
                           data_shape[1],
                           num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress,
                           nms_topk=nms_topk)

    if summarywriter:
        log_dir = '/opt/incubator-mxnet/example/ssd/logs'
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir)  # clear the previous logs
        os.mkdir(log_dir)
        sw = SummaryWriter(logdir=log_dir, flush_secs=flush_secs)
        sw.add_graph(net)
    else:
        sw = None
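    # SummaryWriter here matches the mxboard package's API (from mxboard
    # import SummaryWriter); add_graph(net) logs the symbol graph in
    # TensorBoard format. The import is assumed to happen at module level.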
    # mx.viz.plot_network(net, shape={"data":(64, 3, 320, 320)}, node_attrs={"shape":'rect',"fixedsize":'false'}).view()
    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # init training module
    mod = mx.mod.Module(net,
                        label_names=('label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters

    if summarywriter:
        # Add the visualization callback. With multiple callbacks, every one
        # except the last must keep auto_reset=False so the metric is not
        # cleared before the later callbacks read it.
        batch_end_callbacks = [
            mx.callback.Speedometer(train_iter.batch_size,
                                    frequent=frequent,
                                    auto_reset=False),
            summary_writter_callback.summary_writter_eval_metric(sw)
        ]
    else:
        batch_end_callbacks = [
            mx.callback.Speedometer(train_iter.batch_size,
                                    frequent=frequent,
                                    auto_reset=False)
        ]

    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    }
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh,
                                      use_difficult,
                                      class_names,
                                      pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh,
                                 use_difficult,
                                 class_names,
                                 pred_idx=3)

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callbacks,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
    if summarywriter:
        sw.close()
Example No. 8
def evaluate(netname,
             path_imgrec,
             num_classes,
             num_seg_classes,
             mean_pixels,
             data_shape,
             model_prefix,
             epoch,
             ctx=mx.cpu(),
             batch_size=1,
             path_imglist="",
             nms_thresh=0.45,
             force_nms=False,
             ovp_thresh=0.5,
             use_difficult=False,
             class_names=None,
             seg_class_names=None,
             voc07_metric=False):
    """
    Evaluate network given a validation record file.

    Parameters:
    ----------
    netname : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    path_imglist : str
        path to the list file to replace labels in record file, optional
    num_classes : int
        number of detection classes, not including background
    num_seg_classes : int
        number of segmentation classes
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    seg_class_names : comma separated str
        segmentation class names, used by the IoU metric
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    """
    global outimgiter

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    else:
        # map() returns a lazy iterator in Python 3; build a tuple so len() works
        data_shape = tuple(int(x) for x in data_shape.split(","))
    assert len(data_shape) == 3 and data_shape[0] == 3
    model_prefix += '_' + str(data_shape[1])

    # iterator
    eval_iter = MultiTaskRecordIter(path_imgrec,
                                    batch_size,
                                    data_shape,
                                    path_imglist=path_imglist,
                                    enable_aug=False,
                                    **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if netname is None:
        net = load_net
    elif netname.endswith("det"):
        net = get_det_symbol(netname.split("_")[0],
                             data_shape[1],
                             num_classes=num_classes,
                             nms_thresh=nms_thresh,
                             force_suppress=force_nms)
    elif netname.endswith("seg"):
        net = get_seg_symbol(netname.split("_")[0],
                             data_shape[1],
                             num_classes=num_classes,
                             nms_thresh=nms_thresh,
                             force_suppress=force_nms)
    elif netname.endswith("multi"):
        net = get_multi_symbol(netname.split("_")[0],
                               data_shape[1],
                               num_classes=num_classes,
                               nms_thresh=nms_thresh,
                               force_suppress=force_nms)
    else:
        raise NotImplementedError("unsupported network name: {}".format(netname))

    if 'label_det' not in net.list_arguments():
        label_det = mx.sym.Variable(name='label_det')
        net = mx.sym.Group([net, label_det])
    if 'seg_out_label' not in net.list_arguments():
        seg_out_label = mx.sym.Variable(name='seg_out_label')
        net = mx.sym.Group([net, seg_out_label])

    # set up evaluation; a manual executor loop is used below instead of
    # binding a Module and calling Module.score

    ctx = ctx[0]
    eval_metric = CustomAccuracyMetric()
    multibox_metric = MultiBoxMetric()
    depth_metric = DistanceAccuracyMetric(class_names=class_names)
    det_metric = MApMetric(ovp_thresh, use_difficult, class_names)
    seg_metric = IoUMetric(class_names=seg_class_names, axis=1)
    eval_metrics = metric.CompositeEvalMetric()
    eval_metrics.add(multibox_metric)
    eval_metrics.add(eval_metric)
    arg_params = {key: val.as_in_context(ctx) for key, val in args.items()}
    aux_params = {key: val.as_in_context(ctx) for key, val in auxs.items()}
    data_name = eval_iter.provide_data[0][0]
    label_name_det = eval_iter.provide_label[0][0]
    label_name_seg = eval_iter.provide_label[1][0]
    symbol = load_net

    # evaluation
    logger.info(" in eval process...")
    logger.info(
        str({
            "ovp_thresh": ovp_thresh,
            "nms_thresh": nms_thresh,
            "batch_size": batch_size,
            "force_nms": force_nms,
        }))
    nbatch = 0
    eval_iter.reset()
    eval_metrics.reset()
    det_metric.reset()
    total_time = 0

    for data, fnames in eval_iter:
        nbatch += 1
        label_shape_det = data.label[0].shape
        label_shape_seg = data.label[1].shape
        arg_params[data_name] = mx.nd.array(data.data[0], ctx)
        arg_params[label_name_det] = mx.nd.array(data.label[0], ctx)
        arg_params[label_name_seg] = mx.nd.array(data.label[1], ctx)
        executor = symbol.bind(ctx, arg_params, aux_states=aux_params)

        output_names = symbol.list_outputs()
        output_dict = dict(zip(output_names, executor.outputs))

        cpu_output_array = mx.nd.zeros(output_dict["seg_out_output"].shape)

        ############## monitor status
        def stat_helper(name, array):
            """wrapper for executor callback"""
            import ctypes
            from mxnet.ndarray import NDArray
            from mxnet.base import NDArrayHandle, py_str
            array = ctypes.cast(array, NDArrayHandle)
            if 1:
                array = NDArray(array, writable=False).asnumpy()
                print(name, array.shape, np.mean(array), np.std(array),
                      ('%.1fms' %
                       (float(time.time() - stat_helper.start_time) * 1000)))
            else:
                array = NDArray(array, writable=False)
                array.wait_to_read()
                elapsed = float(time.time() - stat_helper.start_time) * 1000.
                if elapsed > 5:
                    print(name, array.shape, ('%.1fms' % (elapsed, )))
            stat_helper.start_time = time.time()

        stat_helper.start_time = float(time.time())
        # executor.set_monitor_callback(stat_helper)

        ############## forward
        tic = time.time()
        executor.forward(is_train=True)
        output_dict["seg_out_output"].copyto(cpu_output_array)
        pred_shape = output_dict["seg_out_output"].shape
        label = mx.nd.array(data.label[1].reshape(
            (label_shape_seg[0], label_shape_seg[1] * label_shape_seg[2])))
        output_dict["seg_out_output"].wait_to_read()

        toc = time.time()

        seg_out_output = output_dict["seg_out_output"].asnumpy()

        pred_seg_shape = output_dict["seg_out_output"].shape
        label_det = mx.nd.array(data.label[0].reshape(
            (label_shape_det[0], label_shape_det[1] * label_shape_det[2])))
        label_seg = mx.nd.array(data.label[1].reshape(
            (label_shape_seg[0], label_shape_seg[1] * label_shape_seg[2])),
                                ctx=ctx)
        pred_seg = mx.nd.array(output_dict["seg_out_output"].reshape(
            (pred_seg_shape[0], pred_seg_shape[1],
             pred_seg_shape[2] * pred_seg_shape[3])),
                               ctx=ctx)
        #### remove invalid boxes
        out_det = output_dict["det_out_output"].asnumpy()
        indices = np.where(out_det[:, :, 0] >= 0)  # drop padded boxes (class id < 0 marks invalid)
        out_det = np.expand_dims(out_det[indices[0], indices[1], :], axis=0)
        indices = np.where(out_det[:, :, 1] > .1)  # keep detections with confidence above 0.1
        out_det = np.expand_dims(out_det[indices[0], indices[1], :], axis=0)
        # indices = np.where(out_det[:,:,6]<=(100/255.)) # too far away
        # out_det = np.expand_dims(out_det[indices[0],indices[1],:],axis=0)
        pred_det = mx.nd.array(out_det)
        #### remove labels too faraway
        # label_det = label_det.asnumpy().reshape((200,6))
        # indices = np.where(label_det[:,5]<=(100./255.))
        # label_det = np.expand_dims(label_det[indices[0],:],axis=0)
        # label_det = mx.nd.array(label_det)

        ################# display results ####################
        out_img = output_dict["seg_out_output"]
        out_img = mx.nd.split(out_img, axis=0, num_outputs=out_img.shape[0])
        for imgidx in range(batch_size):
            seg_prob = out_img[imgidx]
            res_img = np.squeeze(seg_prob.asnumpy().argmax(axis=0).astype(
                np.uint8))
            label_img = data.label[1].asnumpy()[imgidx, :, :].astype(np.uint8)
            img = np.squeeze(data.data[0].asnumpy()[imgidx, :, :, :])
            det = out_det[imgidx, :, :]
            gt = label_det.asnumpy()[imgidx, :].reshape((-1, 6))
            # save to results folder for evaluation
            res_fname = fnames[imgidx].replace("SegmentationClass", "results")
            lut = np.zeros(256, dtype=np.uint8)  # train IDs -> Cityscapes label IDs
            lut[:19] = np.array([
                7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
                31, 32, 33
            ])
            seg_resized = prob_upsampling(seg_prob, target_shape=(1024, 2048))
            seg_resized2 = cv2.LUT(seg_resized, lut)
            # seg = cv2.LUT(res_img,lut)
            # cv2.imshow("seg",seg.astype(np.uint8))
            cv2.imwrite(res_fname, seg_resized2)
            # display result
            print(fnames[imgidx], np.average(img))
            display_img = display_results(res_img,
                                          np.expand_dims(label_img, axis=0),
                                          img, det, gt, class_names)
            res_fname = fnames[imgidx].replace("SegmentationClass",
                                               "output").replace(
                                                   "labelTrainIds", "output")
            cv2.imwrite(res_fname, display_img)
            if (cv2.waitKey() & 0xff) == 27:  # ESC aborts
                exit(0)
        outimgiter += 1
        ################# display results ####################

        eval_metrics.get_metric(0).update(None, [
            output_dict["cls_prob_output"], output_dict["loc_loss_output"],
            output_dict["cls_label_output"]
        ])
        eval_metrics.get_metric(1).update([label_seg], [pred_seg])
        det_metric.update([mx.nd.slice_axis(data.label[0],axis=2,begin=0,end=5)], \
                                 [mx.nd.slice_axis(pred_det,axis=2,begin=0,end=6)])
        seg_metric.update([label_seg], [pred_seg])
        disparities = []
        for imgidx in range(batch_size):
            dispname = fnames[imgidx].replace("SegmentationClass",
                                              "Disparity").replace(
                                                  "gtFine_labelTrainIds",
                                                  "disparity")
            print(dispname)
            disparities.append(cv2.imread(dispname, -1))
        depth_metric.update(mx.nd.array(disparities), [pred_det])

        det_names, det_values = det_metric.get()
        seg_names, seg_values = seg_metric.get()
        depth_names, depth_values = depth_metric.get()
        total_time += toc - tic
        print("\r %d/%d %.1f%% speed=%.1fms %s=%.1f %s=%.1f %s=%.1f" % (
            nbatch * eval_iter.batch_size,
            eval_iter.num_samples,
            float(nbatch * eval_iter.batch_size) * 100. /
            float(eval_iter.num_samples),
            total_time * 1000. / nbatch,
            det_names[-1],
            det_values[-1] * 100.,
            seg_names[-1],
            seg_values[-1] * 100.,
            depth_names[-1],
            depth_values[-1] * 100.,
        ),
              end='\r')

        # if nbatch>50: break ## debugging

    names, values = eval_metrics.get()
    for name, value in zip(names, values):
        logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
    logger.info('----------------------------------------------')
    names, values = det_metric.get()
    for name, value in zip(names, values):
        logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
    logger.info('----------------------------------------------')
    logger.info(' & '.join(names))
    logger.info(' & '.join(map(lambda v: '%.1f' % (v * 100., ), values)))
    logger.info('----------------------------------------------')
    names, values = depth_metric.get()
    for name, value in zip(names, values):
        logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
    logger.info('----------------------------------------------')
    logger.info(' & '.join(names))
    logger.info(' & '.join(map(lambda v: '%.1f' % (v * 100., ), values)))
    logger.info('----------------------------------------------')
    names, values = seg_metric.get()
    for name, value in zip(names, values):
        logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
    logger.info('----------------------------------------------')
    logger.info(' & '.join(names))
    logger.info(' & '.join(map(lambda v: '%.1f' % (v * 100., ), values)))
Exemplo n.º 9
0
def train_net(net,
              train_path,
              num_classes,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              use_plateau,
              lr_refactor_step,
              lr_refactor_ratio,
              use_global_stats=0,
              freeze_layer_pattern='',
              num_example=10000,
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              class_names=None,
              ignore_names=None,
              optimizer_name='sgd',
              voc07_metric=False,
              nms_topk=400,
              force_suppress=False,
              train_list="",
              val_path="",
              val_list="",
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training; this will replace the embedded labels in the record file
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation; this will replace the embedded labels in the record file
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path,
                               batch_size,
                               data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list,
                               **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path,
                                 batch_size,
                                 data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list,
                                 **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net_str = net
    net = get_symbol_train(net, data_shape[1], \
            use_global_stats=use_global_stats, \
            num_classes=num_classes, ignore_names=ignore_names, \
            nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so
        # freezing everything that starts with 'conv' leaves them trainable
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        try:
            logger.info(
                "Start training with {} from pretrained model {}".format(
                    ctx_str, pretrained))
            _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
            args = convert_pretrained(pretrained, args)
            if net_str == 'ssd_pva':
                args, auxs = convert_pvanet(args, auxs)
        except Exception:
            logger.info(
                "Failed to load the pretrained model. Start from scratch.")
            args = None
            auxs = None
            fixed_param_names = None
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # init training module
    if not use_plateau:  # focal loss does not go well with plateau
        mod = mx.mod.Module(net,
                            label_names=('label', ),
                            logger=logger,
                            context=ctx,
                            fixed_param_names=fixed_param_names)
    else:
        mod = PlateauModule(net,
                            label_names=('label', ),
                            logger=logger,
                            context=ctx,
                            fixed_param_names=fixed_param_names)

    # robust parameter setting
    mod.bind(data_shapes=train_iter.provide_data,
             label_shapes=train_iter.provide_label)
    mod = set_mod_params(mod, args, auxs, logger)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent,
                                                 auto_reset=True)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None
    optimizer_params = {
        'learning_rate': learning_rate,
        'wd': weight_decay,
        'clip_gradient': 4.0,
        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    }
    if optimizer_name == 'sgd':
        optimizer_params['momentum'] = momentum

    # force re-init as a workaround for MXNet issue #7847
    mod.init_optimizer(optimizer=optimizer_name,
                       optimizer_params=optimizer_params,
                       force_init=True)

    if not use_plateau:
        learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                       lr_refactor_step,
                                                       lr_refactor_ratio,
                                                       num_example, batch_size,
                                                       begin_epoch)
        # attach the refreshed rate and schedule so the re-init below uses them
        optimizer_params['learning_rate'] = learning_rate
        optimizer_params['lr_scheduler'] = lr_scheduler
    else:
        w_l1 = cfg.train['smoothl1_weight']
        eval_weights = {
            'CrossEntropy': 1.0,
            'SmoothL1': w_l1,
            'ObjectRecall': 0.0
        }
        plateau_lr = PlateauScheduler( \
                patient_epochs=lr_refactor_step, factor=float(lr_refactor_ratio), eval_weights=eval_weights)
        plateau_metric = MultiBoxMetric(
            fn_stat='/home/hyunjoon/github/additions_mxnet/ssd/stat.txt')

    mod.init_optimizer(optimizer=optimizer_name,
                       optimizer_params=optimizer_params,
                       force_init=True)

    eval_metric = MultiBoxMetric()
    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        map_metric = VOC07MApMetric(ovp_thresh,
                                    use_difficult,
                                    class_names,
                                    pred_idx=4)
        recall_metric = RecallMetric(ovp_thresh, use_difficult, pred_idx=4)
        valid_metric = mx.metric.create([map_metric, recall_metric])
    else:
        valid_metric = MApMetric(ovp_thresh,
                                 use_difficult,
                                 class_names,
                                 pred_idx=4)

    if not use_plateau:
        mod.fit(train_iter,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
    else:
        mod.fit(train_iter,
                plateau_lr,
                plateau_metric=plateau_metric,
                fn_curr_model=prefix + '-1000.params',
                plateau_backtrace=False,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                validation_period=5,
                kvstore='local',
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
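
The non-plateau branch above relies on get_lr_scheduler to turn epoch milestones into an iteration-based schedule. Below is a sketch of what such a helper typically does, assuming lr_refactor_step holds comma-separated epoch numbers such as '30, 60, 90'; it is not the repository's exact implementation.

import mxnet as mx

def make_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                      num_example, batch_size, begin_epoch):
    # convert epoch milestones into iteration counts, applying decays
    # already due when resuming from begin_epoch to the base rate up front
    iters_per_epoch = max(num_example // batch_size, 1)
    epochs = [int(s) for s in str(lr_refactor_step).split(',') if s.strip()]
    for e in epochs:
        if begin_epoch >= e:
            learning_rate *= lr_refactor_ratio
    steps = [(e - begin_epoch) * iters_per_epoch
             for e in epochs if e > begin_epoch]
    scheduler = (mx.lr_scheduler.MultiFactorScheduler(
        step=steps, factor=lr_refactor_ratio) if steps else None)
    return learning_rate, scheduler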
Exemplo n.º 10
0
def evaluate_net(net,
                 path_imgrec,
                 num_classes,
                 mean_pixels,
                 data_shape,
                 model_prefix,
                 epoch,
                 ctx=mx.cpu(),
                 batch_size=1,
                 path_imglist="",
                 nms_thresh=0.45,
                 force_nms=False,
                 ovp_thresh=0.5,
                 use_difficult=False,
                 class_names=None,
                 voc07_metric=False,
                 use_second_network=False,
                 net1=None,
                 path_imgrec1=None,
                 epoch1=None,
                 model_prefix1=None,
                 data_shape1=None):
    """
    evaluate network given validation record file

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    path_imglist : str
        path to the list file to replace labels in record file, optional
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    use_second_network : boolean
        evaluate an additional sub-network if True
    net1, path_imgrec1, epoch1, model_prefix1, data_shape1
        counterparts of the arguments above for the second network
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    elif isinstance(data_shape, list):
        data_shape = (3, data_shape[0], data_shape[1])
    assert len(data_shape) == 3 and data_shape[0] == 3
    # model_prefix += '_' + str(data_shape[1])

    # iterator
    #eval_iter = DetRecordIter(path_imgrec, batch_size, data_shape,
    #                          path_imglist=path_imglist, **cfg.valid)
    curr_path = os.path.abspath(os.path.dirname(__file__))
    imdb_val = load_caltech(image_set='val',
                            caltech_path=os.path.join(
                                curr_path, '..', 'data',
                                'caltech-pedestrian-dataset-converter'),
                            shuffle=False)
    eval_iter = DetIter(imdb_val, batch_size, (data_shape[1], data_shape[2]), \
                       mean_pixels=[128, 128, 128], rand_samplers=[], \
                       rand_mirror=False, shuffle=False, rand_seed=None, \
                       is_train=True, max_crop_trial=50)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if net is None:
        net = load_net
    else:
        #net = get_symbol(net, data_shape[1], num_classes=num_classes,
        net = get_symbol_concat(net,
                                data_shape[1],
                                num_classes=num_classes,
                                nms_thresh=nms_thresh,
                                force_suppress=force_nms)
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        label2 = mx.sym.Variable(name='label2')
        net = mx.sym.Group([net, label, label2])

    # init module
    #mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
    mod = mx.mod.Module(net,
                        label_names=('label', 'label2'),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    if voc07_metric:
        #metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=1)
        metric = VOC07MApMetric(
            ovp_thresh,
            use_difficult,
            class_names,
            pred_idx=[0, 1],
            output_names=['detection_output', 'detection2_output'],
            label_names=['label', 'label2'])
    else:
        #metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=1)
        metric = MApMetric(
            ovp_thresh,
            use_difficult,
            class_names,
            pred_idx=[0, 1],
            output_names=['detection_output', 'detection2_output'],
            label_names=['label', 'label2'])

    # run evaluation
    if not use_second_network:
        results = mod.score(eval_iter, metric, num_batch=None)
        for k, v in results:
            print("{}: {}".format(k, v))
    else:
        logging.basicConfig()
        logger1 = logging.getLogger()
        logger1.setLevel(logging.INFO)

        # load sub network
        if isinstance(data_shape1, int):
            data_shape1 = (3, data_shape1, data_shape1)
        elif isinstance(data_shape1, list):
            data_shape1 = (3, data_shape1[0], data_shape1[1])
        assert len(data_shape1) == 3 and data_shape1[0] == 3

        # iterator
        eval_iter1 = DetRecordIter(path_imgrec1,
                                   batch_size,
                                   data_shape1,
                                   path_imglist=path_imglist,
                                   **cfg.valid)
        # model params
        load_net1, args1, auxs1 = mx.model.load_checkpoint(
            model_prefix1, epoch1)
        # network
        if net1 is None:
            net1 = load_net1
        else:
            net1 = net
        if 'label' not in net1.list_arguments():
            label1 = mx.sym.Variable(name='label')
            net1 = mx.sym.Group([net1, label1])

        # init module
        mod1 = mx.mod.Module(net1,
                             label_names=('label', ),
                             logger=logger1,
                             context=ctx,
                             fixed_param_names=net1.list_arguments())
        mod1.bind(data_shapes=eval_iter1.provide_data,
                  label_shapes=eval_iter1.provide_label)
        mod1.set_params(args1, auxs1, allow_missing=False, force_init=True)

        if voc07_metric:
            metric1 = VOC07MApMetric(ovp_thresh, use_difficult, class_names)
        else:
            metric1 = MApMetric(ovp_thresh, use_difficult, class_names)

        # filepath = '/home/binghao/workspace/MXNet-SSD/matlab/kitti/outputs/ssd/'
        filepath1 = '/home/binghao/workspace/MXNet-SSD/matlab/kitti/outputs/ssd_small/'
        # mod.score_m(filepath, eval_iter, metric, num_batch=None)
        mod1.score_m(filepath1, eval_iter1, metric1, num_batch=None)
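
The mx.sym.Group trick used by both modules above is worth isolating: a deploy symbol loaded from json has no label inputs, so label variables are appended to the output group purely so that Module.bind can accept the iterator's label shapes during scoring. A minimal sketch:

import mxnet as mx

data = mx.sym.Variable('data')     # stand-in for a loaded deploy symbol
label = mx.sym.Variable('label')
net = mx.sym.Group([data, label])
print(net.list_arguments())        # ['data', 'label']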
Exemplo n.º 11
0
def evaluate_net(net, imdb, mean_pixels, data_shape,
                 model_prefix, epoch, ctx=mx.cpu(), batch_size=1,
                 nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False,
                 voc07_metric=False):
    """
    evaluate network given validation record file

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    imdb : Imdb
        image database instance; supplies num_classes and class names
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    num_classes = imdb.num_classes
    class_names = imdb.classes

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    model_prefix += '_' + str(data_shape[1])

    # iterator
    eval_iter = FaceTestIter(imdb, mean_pixels, img_stride=128, fix_hw=True)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if net is None:
        net = load_net
    else:
        net = get_symbol(net, data_shape[1], num_classes=num_classes,
            nms_thresh=nms_thresh, force_suppress=force_nms)
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    # init module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
        fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data, label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # run evaluation
    if voc07_metric:
        metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names)
    else:
        metric = MApMetric(ovp_thresh, use_difficult, class_names)

    results = []
    for i, (datum, im_info) in enumerate(eval_iter):
        mod.reshape(data_shapes=datum.provide_data, label_shapes=datum.provide_label)
        mod.forward(datum)

        preds = mod.get_outputs()

        det0 = preds[0][0].asnumpy() # (n_anchor, 6)
        det0 = do_nms(det0, 1, nms_thresh)
        preds[0][0] = mx.nd.array(det0, ctx=preds[0].context)

        sy, sx, _ = im_info['im_shape']
        scaler = mx.nd.array((1.0, sx, sy, sx, sy, 1.0))
        scaler = mx.nd.reshape(scaler, (1, 1, -1))

        datum.label[0] *= scaler
        metric.update(datum.label, preds)

        if i % 10 == 0:
            print('processed {} images.'.format(i))
        # if i == 10:
        #     break

    results = metric.get_name_value()
    for k, v in results:
        print("{}: {}".format(k, v))
Exemplo n.º 12
0
def evaluate_net(net,
                 path_imgrec,
                 num_classes,
                 mean_pixels,
                 data_shape,
                 model_prefix,
                 epoch,
                 ctx=mx.cpu(),
                 batch_size=1,
                 path_imglist="",
                 nms_thresh=0.45,
                 force_nms=False,
                 ovp_thresh=0.5,
                 use_difficult=False,
                 class_names=None,
                 voc07_metric=False,
                 frequent=20):
    """
    evaluate network given validation record file

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    path_imglist : str
        path to the list file to replace labels in record file, optional
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    frequent : int
        frequency to print out validation status
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    #model_prefix += '_' + str(data_shape[1])

    # iterator
    eval_iter = DetRecordIter(path_imgrec,
                              batch_size,
                              data_shape,
                              mean_pixels=mean_pixels,
                              label_pad_width=350,
                              path_imglist=path_imglist,
                              **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if net is None:
        net = load_net
    else:
        net = get_symbol(net,
                         data_shape[1],
                         num_classes=num_classes,
                         nms_thresh=nms_thresh,
                         force_suppress=force_nms)
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    # init module
    mod = mx.mod.Module(net,
                        label_names=('label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # run evaluation
    if voc07_metric:
        metric = VOC07MApMetric(ovp_thresh,
                                use_difficult,
                                class_names,
                                roc_output_path=os.path.join(
                                    os.path.dirname(model_prefix), 'roc'))
    else:
        metric = MApMetric(ovp_thresh,
                           use_difficult,
                           class_names,
                           roc_output_path=os.path.join(
                               os.path.dirname(model_prefix), 'roc'))

    posemetric = PoseMetric(
        LINEMOD_path='/data/ZHANGXIN/DATASETS/SIXD_CHALLENGE/LINEMOD/',
        classes=class_names)

    # visualize bb8 results
    # for nbatch, eval_batch in tqdm(enumerate(eval_iter)):
    #     mod.forward(eval_batch)
    #     preds = mod.get_outputs(merge_multi_context=True)
    #
    #     labels = eval_batch.label[0].asnumpy()
    #     # get generated multi label from network
    #     cls_prob = preds[0]
    #     loc_pred = preds[4]
    #     bb8_pred = preds[5]
    #     anchors = preds[6]
    #
    #     bb8dets = BB8MultiBoxDetection(cls_prob, loc_pred, bb8_pred, anchors, nms_threshold=0.5, force_suppress=False,
    #                                   variances=(0.1, 0.1, 0.2, 0.2), nms_topk=400)
    #     bb8dets = bb8dets.asnumpy()
    #
    #     for nsample, sampleDet in enumerate(bb8dets):
    #         image = eval_batch.data[0][nsample].asnumpy()
    #         image += np.array(mean_pixels).reshape((3, 1, 1))
    #         image = np.transpose(image, axes=(1, 2, 0))
    #         draw_dets = []
    #         draw_cids = []
    #
    #         for instanceDet in sampleDet:
    #             if instanceDet[0] == -1:
    #                 continue
    #             else:
    #                 cid = instanceDet[0].astype(np.int16)
    #                 indices = np.where(sampleDet[:, 0] == cid)[0]
    #
    #                 if indices.size > 0:
    #                     draw_dets.append(sampleDet[indices[0], 6:])
    #                     draw_cids.append(cid)
    #                     sampleDet = np.delete(sampleDet, indices, axis=0)
    #                     show_BB8(image / 255., np.transpose(draw_dets[-1].reshape((-1, 8, 2)), axes=(0,2,1)), [cid],
    #                              plot_path='./output/bb8results/{:04d}_{}'.format(nbatch * batch_size + nsample, class_names[cid]))
    #
    #         # draw_dets = np.array(draw_dets)
    #         # draw_cids = np.array(draw_cids)
    #
    #         # show_BB8(image / 255., np.transpose(draw_dets.reshape((-1, 8, 2)), axes=(0,2,1)), draw_cids,
    #         #          plot_path='./output/bb8results/{:04d}'.format(nbatch * batch_size + nsample))

    # quantitive results
    results = mod.score(eval_iter, [metric, posemetric],
                        num_batch=None,
                        batch_end_callback=mx.callback.Speedometer(
                            batch_size, frequent=frequent, auto_reset=False))

    results_save_path = os.path.join(os.path.dirname(model_prefix),
                                     'evaluate_results')
    with open(results_save_path, 'w') as f:
        for k, v in results:
            print("{}: {}".format(k, v))
            f.write("{}: {}\n".format(k, v))

    reproj_save_path = os.path.join(os.path.dirname(model_prefix),
                                    'reprojection_error')
    with open(reproj_save_path, 'wb') as f:
        # pickle per-class reprojection errors for offline analysis
        pickle.dump(posemetric.Reproj, f, protocol=2)

    count_save_path = os.path.join(os.path.dirname(model_prefix), 'gt_count')
    with open(count_save_path, 'wb') as f:
        # pickle per-class ground-truth counts
        pickle.dump(posemetric.counts, f, protocol=2)
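
The two pickle files written above can be read back for offline analysis. A sketch, assuming posemetric.Reproj maps class name to a list of reprojection errors and posemetric.counts maps class name to a ground-truth count; the 5-pixel threshold is an illustrative choice, not part of the original code:

import pickle

with open('reprojection_error', 'rb') as f:
    reproj = pickle.load(f)
with open('gt_count', 'rb') as f:
    counts = pickle.load(f)
for cls, errors in reproj.items():
    correct = sum(1 for e in errors if e < 5.0)
    total = max(int(counts.get(cls, len(errors))), 1)
    print("{}: {:.1f}% under 5px".format(cls, 100.0 * correct / total))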
Exemplo n.º 13
0
def train_net(net,
              dataset,
              image_set,
              devkit_path,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              lr_refactor_step,
              lr_refactor_ratio,
              year='',
              val_image_set=None,
              val_year='',
              freeze_layer_pattern='',
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              voc07_metric=False,
              nms_topk=400,
              force_suppress=False,
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    dataset : str
        pascal_voc, imagenet...
    image_set : str
        train, trainval...
    devkit_path : str
        root directory of dataset
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    year : str
        2007, 2012 or combinations separated by comma
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # load dataset
    if dataset == 'pascal_voc':
        imdb = load_pascal(image_set, year, devkit_path,
                           cfg.train['init_shuffle'])
        if val_image_set and val_image_set != '' and val_year:
            val_imdb = load_pascal(val_image_set, val_year, devkit_path, False)
        else:
            val_imdb = None
    else:
        raise NotImplementedError("Dataset " + dataset + " not supported")

    rand_scaler = RandScaler(min_scale=cfg.train['min_aug_scale'],
                             max_scale=cfg.train['max_aug_scale'],
                             min_gt_scale=cfg.train['min_aug_gt_scale'],
                             max_trials=cfg.train['max_aug_trials'],
                             max_sample=cfg.train['max_aug_sample'],
                             patch_size=cfg.train['aug_patch_size'])
    # init data iterator
    train_iter = DetIter(imdb,
                         batch_size,
                         data_shape,
                         mean_pixels,
                         rand_scaler,
                         cfg.train['rand_mirror'],
                         cfg.train['epoch_shuffle'],
                         cfg.train['seed'],
                         is_train=True)
    # TODO: enable val_iter
    val_iter = None
    # if val_imdb:
    #     val_iter = DetIter(val_imdb, batch_size, data_shape, mean_pixels,
    #                        cfg.valid.rand_scaler, cfg.valid.rand_mirror,
    #                        cfg.valid.epoch_shuffle, cfg.valid.seed,
    #                        is_train=True)
    # else:
    #     val_iter = None

    # load symbol
    sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
    symbol_module = importlib.import_module("symbol_" + net)
    net = symbol_module.get_symbol_train(imdb.num_classes,
                                         nms_thresh=nms_thresh,
                                         force_suppress=force_suppress,
                                         nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so
        # freezing everything that starts with 'conv' leaves them trainable
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # init training module
    mod = mx.mod.Module(net,
                        label_names=('label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    # num_example is not a parameter of this wrapper, so use the dataset size
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   imdb.num_images, batch_size,
                                                   begin_epoch)
    optimizer_params = {
        'learning_rate': learning_rate,
        'wd': weight_decay,
        'clip_gradient': 10.0,
        'rescale_grad': 1.0,
        'lr_scheduler': lr_scheduler
    }
    # optimizer_params={'learning_rate':learning_rate,
    #                   'momentum':momentum,
    #                   'wd':weight_decay,
    #                   'lr_scheduler':lr_scheduler,
    #                   'clip_gradient':None,
    #                   'rescale_grad': 1.0}
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh,
                                      use_difficult,
                                      imdb.classes,
                                      pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh,
                                 use_difficult,
                                 imdb.classes,
                                 pred_idx=3)

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='adam',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
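
All of these wrappers share the freeze_layer_pattern mechanic: any argument whose name matches the regex is passed to Module as fixed_param_names and excluded from gradient updates. The same selection in isolation, with illustrative layer names:

import re

arg_names = ['conv1_1_weight', 'conv1_1_bias', 'conv3_2_weight', 'fc7_weight']
re_prog = re.compile('^conv1_')
fixed_param_names = [name for name in arg_names if re_prog.match(name)]
print(fixed_param_names)  # ['conv1_1_weight', 'conv1_1_bias']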
Exemplo n.º 14
0
def train_multitask(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training; this will replace the embedded labels in the record file
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation; this will replace the embedded labels in the record file
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    logger.info(str({"train_path":train_path,"batch_size":batch_size,"data_shape":data_shape}))
    train_iter = MultiTaskRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
        label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = MultiTaskRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
            label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
        nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '('+ ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so
        # freezing everything that starts with 'conv' leaves them trainable
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
            .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    logger.info("Creating Base Module ...")
    mod = mx.mod.Module(net, label_names=('label_det','seg_out_label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
        lr_refactor_ratio, num_example, batch_size, begin_epoch)
    optimizer_params={'learning_rate':learning_rate,
                      'momentum':momentum,
                      'wd':weight_decay,
                      'lr_scheduler':lr_scheduler,
                      'clip_gradient':None,
                      'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0 }
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)

    from pprint import pprint
    import numpy as np
    import cv2
    from palette import color2index, index2color

    pprint(optimizer_params)
    np.set_printoptions(formatter={"float":lambda x:"%.3f "%x},suppress=True)

    ############### the following lines visualize the network structure ###########################
    internals = net.get_internals()
    print(net)
    # print(internals)
    
    ############### the following lines visualize the training data ###########################
    data_batch = train_iter.next()
    pprint({"data":data_batch.data[0].shape,
            "label_det":data_batch.label[0].shape,
            "seg_out_label":data_batch.label[1].shape})
    data = data_batch.data[0].asnumpy()
    label = data_batch.label[0].asnumpy()
    segmt = data_batch.label[1].asnumpy()
    for ii in range(data.shape[0]):
        img = data[ii,:,:,:]
        seg = segmt[ii,:,:]
        print(label[ii, :5, :])
        img = np.squeeze(img)
        img = np.swapaxes(img, 0, 2)
        img = np.swapaxes(img, 0, 1)
        img = (img + np.array([123.68, 116.779, 103.939]).reshape((1,1,3))).astype(np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        rois = label[ii,:,:]
        hh, ww, ch = img.shape
        for lidx in range(rois.shape[0]):
            roi = rois[lidx,:]
            if roi[0]>=0:
                cv2.rectangle(img, (int(roi[1]*ww),int(roi[2]*hh)), (int(roi[3]*ww),int(roi[4]*hh)), (0,0,128))
                cls_id = int(roi[0])
                bbox = [int(roi[1]*ww),int(roi[2]*hh),int(roi[3]*ww),int(roi[4]*hh)]
                text = '%s %.0fm' % (class_names[cls_id], roi[5]*255.)
                putText(img,bbox,text)
        disp = np.zeros((hh*2, ww, ch),np.uint8)
        disp[:hh,:, :] = img.astype(np.uint8)
        disp[hh:,:, :] = index2color(seg)
        cv2.imshow("img", disp)
        if cv2.waitKey()&0xff==27: exit(0)
        
    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
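
The visualization loop above undoes the iterator's preprocessing before display: data arrives as CHW float arrays with the RGB means subtracted, so display requires adding the means back, moving channels last and converting RGB to OpenCV's BGR order. The same steps as a helper (a sketch; the mean values mirror the ones used above):

import cv2
import numpy as np

MEAN_RGB = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 3))

def to_display_image(chw):
    img = np.transpose(chw, (1, 2, 0))                     # CHW -> HWC
    img = np.clip(img + MEAN_RGB, 0, 255).astype(np.uint8)
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)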
Exemplo n.º 15
0
def train_net(network,
              train_path,
              num_classes,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              lr_refactor_step,
              lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000,
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              class_names=None,
              voc07_metric=False,
              nms_topk=400,
              force_suppress=False,
              train_list="",
              val_path="",
              val_list="",
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None,
              optimizer='sgd',
              tensorboard=False,
              checkpoint_period=5,
              min_neg_samples=0):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    network : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    optimizer : str
        which optimizer to use, other than the default sgd
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training; this will replace the embedded labels in the record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation; this will replace the embedded labels in the record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    tensorboard : bool
        record logs into tensorboard
    min_neg_samples : int
        always keep some negative examples, no matter how many positives there are;
        this is useful when training on images with no ground truth.
    checkpoint_period : int
        a checkpoint will be saved every "checkpoint_period" epochs
    """
    # check actual number of training images using the record's index file
    if os.path.exists(train_path.replace('.rec', '.idx')):
        with open(train_path.replace('.rec', '.idx'), 'r') as f:
            num_example = len(f.readlines())
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = None  # also referenced by the tensorboard callbacks below
    if log_file:
        log_file_path = os.path.join(os.path.dirname(prefix), log_file)
        if not os.path.exists(os.path.dirname(log_file_path)):
            os.makedirs(os.path.dirname(log_file_path))
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path,
                               batch_size,
                               data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list,
                               **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path,
                                 batch_size,
                                 data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list,
                                 **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(network,
                           data_shape[1],
                           num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress,
                           nms_topk=nms_topk,
                           minimum_negative_samples=min_neg_samples)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))

        # load from the pretrained prefix rather than the training prefix
        _, args, auxs = mx.model.load_checkpoint(pretrained, finetune)
        begin_epoch = finetune
        # check which layers mismatch the loaded parameters and drop them
        exe = net.simple_bind(mx.cpu(),
                              data=(1, 3, 300, 300),
                              label=(1, 1, 5),
                              grad_req='null')
        arg_dict = exe.arg_dict
        fixed_param_names = []
        for k, v in arg_dict.items():
            if k in args and v.shape != args[k].shape:
                del args[k]
                logging.info("Removed %s" % k)

    elif pretrained and not finetune:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # visualize net - both train and test
    net_visualization(net=net,
                      network=network,
                      data_shape=data_shape[2],
                      output_dir=os.path.dirname(prefix),
                      train=True)
    net_visualization(net=None,
                      network=network,
                      data_shape=data_shape[2],
                      output_dir=os.path.dirname(prefix),
                      train=False,
                      num_classes=num_classes)

    # init training module
    mod = mx.mod.Module(net,
                        label_names=('label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    batch_end_callback = []
    eval_end_callback = []
    epoch_end_callback = [
        mx.callback.do_checkpoint(prefix, period=checkpoint_period)
    ]

    # add logging to tensorboard
    if tensorboard:
        tensorboard_dir = os.path.join(os.path.dirname(prefix), 'logs')
        if not os.path.exists(tensorboard_dir):
            os.makedirs(os.path.join(tensorboard_dir, 'train', 'scalar'))
            os.makedirs(os.path.join(tensorboard_dir, 'train', 'dist'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'roc'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'scalar'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'images'))
        batch_end_callback.append(
            ParseLogCallback(
                dist_logging_dir=os.path.join(tensorboard_dir, 'train',
                                              'dist'),
                scalar_logging_dir=os.path.join(tensorboard_dir, 'train',
                                                'scalar'),
                logfile_path=log_file_path,
                batch_size=batch_size,
                iter_monitor=iter_monitor,
                frequent=frequent))
        eval_end_callback.append(
            mx.contrib.tensorboard.LogMetricsCallback(
                os.path.join(tensorboard_dir, 'val/scalar'), 'ssd'))
        eval_end_callback.append(
            LogROCCallback(logging_dir=os.path.join(tensorboard_dir,
                                                    'val/roc'),
                           roc_path=os.path.join(os.path.dirname(prefix),
                                                 'roc'),
                           class_names=class_names))
        eval_end_callback.append(
            LogDetectionsCallback(
                logging_dir=os.path.join(tensorboard_dir, 'val/images'),
                images_path=os.path.join(os.path.dirname(prefix), 'images'),
                class_names=class_names,
                batch_size=batch_size,
                mean_pixels=mean_pixels))

    # this callback should be the last in the series of batch callbacks,
    # since it resets the metric evaluation every `frequent` batches
    batch_end_callback.append(
        mx.callback.Speedometer(train_iter.batch_size, frequent=frequent))

    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    # add possibility for different optimizer
    opt, opt_params = get_optimizer_params(optimizer=optimizer,
                                           learning_rate=learning_rate,
                                           momentum=momentum,
                                           weight_decay=weight_decay,
                                           lr_scheduler=lr_scheduler,
                                           ctx=ctx,
                                           logger=logger)
    # TODO monitor the gradient flow as in 'https://github.com/dmlc/tensorboard/blob/master/docs/tutorial/understanding-vanish-gradient.ipynb'
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit; every n epochs the evaluation network is run to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh,
                                      use_difficult,
                                      class_names,
                                      pred_idx=3,
                                      roc_output_path=os.path.join(
                                          os.path.dirname(prefix), 'roc'))
    else:
        valid_metric = MApMetric(ovp_thresh,
                                 use_difficult,
                                 class_names,
                                 pred_idx=3,
                                 roc_output_path=os.path.join(
                                     os.path.dirname(prefix), 'roc'))

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            eval_end_callback=eval_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer=opt,
            optimizer_params=opt_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
Exemplo n.º 16
0
def evaluate_net(net,
                 path_imgrec,
                 num_classes,
                 num_batch,
                 mean_pixels,
                 data_shape,
                 model_prefix,
                 epoch,
                 ctx=mx.cpu(),
                 batch_size=32,
                 path_imglist="",
                 nms_thresh=0.45,
                 force_nms=False,
                 ovp_thresh=0.5,
                 use_difficult=False,
                 class_names=None,
                 voc07_metric=False,
                 lite=False):
    """
    Evaluate network given a validation record file

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    path_imglist : str
        path to the list file to replace labels in record file, optional
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in the VOC07 competition
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    model_prefix += '_' + str(data_shape[1]) + '_' + str(data_shape[2])

    # iterator
    eval_iter = DetRecordIter(path_imgrec,
                              batch_size,
                              data_shape,
                              mean_pixels=mean_pixels,
                              path_imglist=path_imglist,
                              **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if net is None:
        net = load_net
    else:
        net = get_symbol(net,
                         data_shape,
                         num_classes=num_classes,
                         nms_thresh=nms_thresh,
                         force_suppress=force_nms,
                         lite=lite)
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    # init module
    mod = mx.mod.Module(net,
                        label_names=('label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # run evaluation
    if voc07_metric:
        metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names)
    else:
        metric = MApMetric(ovp_thresh, use_difficult, class_names)

    num = num_batch * batch_size
    data = [
        mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx)
        for _, shape in mod.data_shapes
    ]
    batch = mx.io.DataBatch(data, [])  # empty label

    dry_run = 5  # use 5 iterations to warm up
    for i in range(dry_run):
        mod.forward(batch, is_train=False)
        for output in mod.get_outputs():
            output.wait_to_read()

    tic = time.time()
    results = mod.score(eval_iter,
                        metric,
                        num_batch=None,
                        batch_end_callback=mx.callback.Speedometer(
                            batch_size, frequent=10, auto_reset=False))
    speed = num / (time.time() - tic)
    if logger is not None:
        logger.info('Finished inference with %d images', num)
        logger.info('Finished with %f images per second', speed)

    for k, v in results:
        print("{}: {}".format(k, v))
Exemplo n.º 17
0
def train_net(net,
              train_path,
              num_classes,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              lr_refactor_step,
              lr_refactor_ratio,
              freeze_layer_pattern='',
              shape_range=(320, 512),
              random_shape_step=0,
              random_shape_epoch=10,
              num_example=10000,
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              class_names=None,
              voc07_metric=False,
              nms_topk=400,
              force_suppress=False,
              train_list="",
              val_path="",
              val_list="",
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    shape_range : tuple of (min, max)
        random data shape range
    random_shape_step : int
        step size for random data shape, defined by network, 0 to disable
    random_shape_epoch : int
        number of epochs before switching to the next random shape
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training; this will replace the embedded labels in the record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation; this will replace the embedded labels in the record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    # note: str.strip('_yolo') strips a set of characters, not the suffix,
    # so remove the '_yolo' suffix explicitly
    base_net = net[:-len('_yolo')] if net.endswith('_yolo') else net
    prefix += '_' + base_net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # load symbol
    sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
    symbol_module = importlib.import_module("symbol_" + net)
    net = symbol_module.get_symbol(num_classes,
                                   nms_thresh=nms_thresh,
                                   force_suppress=force_suppress,
                                   nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    allow_missing = True
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
        allow_missing = False
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so it is
        # safe to freeze every layer whose name starts with 'conv'
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit; every n epochs the evaluation network is run to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh,
                                      use_difficult,
                                      class_names,
                                      pred_idx=0)
    else:
        valid_metric = MApMetric(ovp_thresh,
                                 use_difficult,
                                 class_names,
                                 pred_idx=0)

    # init training module
    mod = mx.mod.Module(net,
                        label_names=('yolo_output_label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    random_shape_step = int(random_shape_step)
    if random_shape_step > 0:
        fit_begins = list(range(begin_epoch, end_epoch, random_shape_epoch))
        fit_ends = fit_begins[1:] + [end_epoch]
        assert (len(shape_range) == 2)
        data_shapes = [(3, x * random_shape_step, x * random_shape_step) \
            for x in range(shape_range[0] // random_shape_step,
            shape_range[1] // random_shape_step + 1)]
        logger.info("Candidate random shapes:" + str(data_shapes))
    else:
        fit_begins = [begin_epoch]
        fit_ends = [end_epoch]
        data_shapes = [data_shape]
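    # e.g. with shape_range=(320, 512) and random_shape_step=32, the candidate
    # shapes above are (3, 320, 320), (3, 352, 352), ..., (3, 512, 512)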

    for begin, end in zip(fit_begins, fit_ends):
        if len(data_shapes) == 1:
            data_shape = data_shapes[0]
        else:
            data_shape = data_shapes[random.randint(0, len(data_shapes) - 1)]
            logger.info("Setting random data shape: " + str(data_shape))

        train_iter = DetRecordIter(train_path,
                                   batch_size,
                                   data_shape,
                                   mean_pixels=mean_pixels,
                                   label_pad_width=label_pad_width,
                                   path_imglist=train_list,
                                   **cfg.train)

        if val_path:
            val_iter = DetRecordIter(val_path,
                                     batch_size,
                                     data_shape,
                                     mean_pixels=mean_pixels,
                                     label_pad_width=label_pad_width,
                                     path_imglist=val_list,
                                     **cfg.valid)
        else:
            val_iter = None

        learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                       lr_refactor_step,
                                                       lr_refactor_ratio,
                                                       num_example, batch_size,
                                                       begin_epoch)
        optimizer_params = {
            'learning_rate': learning_rate,
            'momentum': momentum,
            'wd': weight_decay,
            'lr_scheduler': lr_scheduler,
            'clip_gradient': 10,
            'rescale_grad': 1.0
        }
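        # unlike the SSD trainers above, this configuration clips gradients
        # at 10 and keeps rescale_grad fixed at 1.0 regardless of contexts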

        mod.fit(train_iter,
                val_iter,
                eval_metric=MultiBoxMetric(),
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer='sgd',
                optimizer_params=optimizer_params,
                begin_epoch=begin,
                num_epoch=end,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=allow_missing,
                monitor=monitor,
                force_rebind=True,
                force_init=True)

        args, auxs = mod.get_params()
        allow_missing = False
Exemplo n.º 18
0
def train_net(net,
              train_path,
              num_classes,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              lr_refactor_step,
              lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000,
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              class_names=None,
              voc07_metric=False,
              nms_topk=2000,
              force_suppress=False,
              train_list="",
              val_path="",
              val_list="",
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training; this will replace the embedded labels in the record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation; this will replace the embedded labels in the record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # optionally log to a file as well
    if log_file:
        # create a FileHandler
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])
    # ensure mean_pixels is a 3-element list
    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path,
                               batch_size,
                               data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list,
                               **cfg.train)

    # Uncomment to visualize training batches with their boxes (debugging aid):
    # for c in range(12840):
    #     batch = train_iter.next()
    #     data = batch.data[0]
    #     label = batch.label[0]
    #     from matplotlib import pyplot as plt
    #     import numpy as np
    #     import cv2
    #     for i in range(2):
    #         plt.subplot(1, 2, i + 1)
    #         img = np.array(data[i].asnumpy().transpose(1, 2, 0).copy(), np.uint8)
    #         box = label[i].asnumpy()
    #         bbox = []
    #         print('The', i, 'th image')
    #         for j in range(box.shape[0]):
    #             if box[j][0] == -1:
    #                 break
    #             bbox.append(box[j][1:5])
    #         for k in range(len(bbox)):
    #             xmin = (bbox[k][0] * img.shape[0]).astype(np.int16)
    #             ymin = (bbox[k][1] * img.shape[0]).astype(np.int16)
    #             xmax = (bbox[k][2] * img.shape[0]).astype(np.int16)
    #             ymax = (bbox[k][3] * img.shape[0]).astype(np.int16)
    #             cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 0), 4)
    #             print('xmin', xmin, 'ymin', ymin, 'xmax', xmax, 'ymax', ymax)
    #         plt.imshow(img)
    #     plt.show()
    #     # path = 'crop_image/' + str(c) + '.jpg'
    #     # plt.savefig(path)
    #     print(batch)

    if val_path:
        val_iter = DetRecordIter(val_path,
                                 batch_size,
                                 data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list,
                                 **cfg.valid)
    else:
        val_iter = None
    # load symbol
    net = get_symbol_train(net,
                           data_shape[1],
                           num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress,
                           nms_topk=nms_topk)
    # viz = mx.viz.plot_network(net)
    # viz.view()
    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so it is
        # safe to freeze every layer whose name starts with 'conv'
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        fixed_param_names = None
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # init training module
    mod = mx.mod.Module(net,
                        label_names=('label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    }
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit; every n epochs the evaluation network is run to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh,
                                      use_difficult,
                                      class_names,
                                      pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh,
                                 use_difficult,
                                 class_names,
                                 pred_idx=3)

    mod.fit(
        train_data=train_iter,
        eval_data=val_iter,
        eval_metric=MultiBoxMetric(),
        validation_metric=valid_metric,
        batch_end_callback=batch_end_callback,
        epoch_end_callback=epoch_end_callback,
        optimizer='sgd',
        optimizer_params=optimizer_params,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        initializer=mx.init.Xavier(),
        arg_params=args,
        aux_params=auxs,
        allow_missing=True,
        monitor=monitor)