def __init__(self, symbol, ctx=None,
             begin_epoch=0, num_epoch=None,
             arg_params=None, aux_params=None,
             valid_metric=MApMetric(), class_names=[],
             optimizer='sgd', **kwargs):
    """Record the solver configuration on the instance.

    Parameters:
    ----------
    symbol : object
        network symbol, stored unchanged
    ctx : object or None
        device context; ``mx.cpu(0)`` is used when None
    begin_epoch : int
        stored unchanged
    num_epoch : int or None
        stored unchanged
    arg_params : dict or None
        stored unchanged
    aux_params : dict or None
        stored unchanged
    valid_metric : object
        validation metric, stored unchanged
    class_names : list
        class names; a list argument is copied so the instance owns its list
    optimizer : str
        optimizer name, stored unchanged
    kwargs : dict
        extra keyword arguments; a shallow copy is kept in ``self.kwargs``
    """
    if ctx is None:
        ctx = mx.cpu(0)

    # The ``[]`` default is built once, when the function is defined. Storing
    # it directly would let an in-place edit of one instance's class_names
    # show up in every other instance that relied on the default. Copy list
    # arguments (same treatment as kwargs below); None and non-list values
    # are stored untouched.
    if isinstance(class_names, list):
        class_names = list(class_names)

    # NOTE(review): the default ``MApMetric()`` is likewise created once at
    # definition time and is shared by all instances that do not pass their
    # own metric. It is left as-is because replacing it would change what an
    # explicit ``valid_metric=None`` means to callers.
    self.symbol = symbol
    self.ctx = ctx
    self.begin_epoch = begin_epoch
    self.num_epoch = num_epoch
    self.arg_params = arg_params
    self.aux_params = aux_params
    self.valid_metric = valid_metric
    self.class_names = class_names
    self.optimizer = optimizer
    self.evaluation_only = False
    self.kwargs = kwargs.copy()
def train_multitask(netname, train_path, num_classes, batch_size, data_shape,
                    mean_pixels, resume, finetune, pretrained, epoch, prefix,
                    ctx, begin_epoch, end_epoch, frequent, learning_rate,
                    momentum, weight_decay, lr_refactor_step,
                    lr_refactor_ratio, freeze_layer_pattern='',
                    num_example=10000, label_pad_width=350, nms_thresh=0.45,
                    force_nms=False, ovp_thresh=0.5, use_difficult=False,
                    class_names=None, voc07_metric=False, nms_topk=400,
                    force_suppress=False, train_list="", val_path="",
                    val_list="", iter_monitor=0, monitor_pattern=".*",
                    log_file=None):
    """
    Wrapper for the multi-task (detection / segmentation) training phase.

    Parameters:
    ----------
    netname : str
        symbol name for the network structure; the suffix ("det", "seg" or
        "multi") selects both the training symbol and the solver
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : number or tuple of floats
        mean pixel values for red, green and blue; a single number is
        repeated for all three channels
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        accepted for interface compatibility; not used by this function
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of the pretrained model
    prefix : str
        prefix for saving checkpoints; '_<netname>_<height>' is appended
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts; only the first entry is used for training
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        accepted for interface compatibility; not used by this function
        (``force_suppress`` is what is forwarded to the symbol)
    ovp_thresh : float
        overlap threshold handed to the validation metric
    use_difficult : boolean
        handed to the validation metric
    class_names : list or None
        handed to the validation metric and to the solver
    voc07_metric : boolean
        accepted for interface compatibility; not used by this function
    nms_topk : int
        forwarded to the training symbol
    force_suppress : boolean
        forwarded to the training symbol
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log file path; when empty a timestamped file under 'log/' is used

    Raises:
    ------
    NotImplementedError
        if ``netname`` maps to no training symbol or to no solver
    """
    # set up logger
    logger = logging.getLogger()
    # The root logger starts at WARNING, which would drop every info message
    # below before it reaches a handler. Only lower the threshold; a more
    # verbose level configured by the caller is left alone.
    if logger.level > logging.INFO:
        logger.setLevel(logging.INFO)
    if log_file:
        log_path = log_file
    else:
        # Same file name the '%F-%T' pattern produced, spelled with portable
        # directives ('%F'/'%T' are not available on every platform).
        log_path = os.path.join(
            'log',
            time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + '.log')
    log_dir = os.path.dirname(log_path)
    if log_dir and not os.path.isdir(log_dir):
        # FileHandler does not create missing directories.
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_path)
    fh.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    logger.addHandler(fh)
    logger.addHandler(ch)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + netname + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    logger.info(
        str({
            "train_path": train_path,
            "batch_size": batch_size,
            "data_shape": data_shape
        }))
    train_iter = MultiTaskRecordIter(train_path,
                                     batch_size,
                                     data_shape,
                                     mean_pixels=mean_pixels,
                                     label_pad_width=label_pad_width,
                                     path_imglist=train_list,
                                     enable_aug=True,
                                     **cfg.train)

    if val_path:
        val_iter = MultiTaskRecordIter(val_path,
                                       batch_size,
                                       data_shape,
                                       mean_pixels=mean_pixels,
                                       label_pad_width=label_pad_width,
                                       path_imglist=val_list,
                                       enable_aug=False,
                                       **cfg.valid)
    else:
        val_iter = None

    # load symbol
    logger.info(
        str({
            "num_classes": num_classes,
            "nms_thresh": nms_thresh,
            "force_suppress": force_suppress,
            "nms_topk": nms_topk
        }))
    symbol_kwargs = dict(num_classes=num_classes,
                         nms_thresh=nms_thresh,
                         force_suppress=force_suppress,
                         nms_topk=nms_topk)
    if netname in ("resnet-18", "resnet-50"):
        net = get_fcn32s_symbol_train(netname, data_shape[1], **symbol_kwargs)
    elif netname.endswith("det"):
        net = get_det_symbol_train(netname.split("_")[0], data_shape[1],
                                   **symbol_kwargs)
    elif netname.endswith("seg"):
        net = get_seg_symbol_train(netname.split("_")[0], data_shape[1],
                                   **symbol_kwargs)
    elif netname.endswith("multi"):
        net = get_multi_symbol_train(netname.split("_")[0], data_shape[1],
                                     **symbol_kwargs)
    else:
        raise NotImplementedError(
            "no training symbol for network '{}'".format(netname))

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    # Only the first context is used from here on.
    ctx = ctx[0]
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
        args = {key: val.as_in_context(ctx) for key, val in args.items()}
        auxs = {key: val.as_in_context(ctx) for key, val in auxs.items()}
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = {key: val.as_in_context(ctx) for key, val in args.items()}
        auxs = {key: val.as_in_context(ctx) for key, val in auxs.items()}
        args, auxs = init_from_resnet(ctx, net, args, auxs)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args, auxs, fixed_param_names = None, None, None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # init training module
    # NOTE(review): `mod` is not referenced again below; training runs
    # through the solver object instead.
    logger.info("Creating Module ...")
    mod = mx.mod.Module(net,
                        label_names=(
                            'label_det',
                            'seg_out_label',
                        ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    # Printed for reference only; the solver receives the same values as
    # individual keyword arguments.
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        'rescale_grad': 1.0
    }
    # NOTE(review): `monitor` is built but not handed to the solver.
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    valid_metric = MApMetric(ovp_thresh,
                             use_difficult,
                             class_names,
                             pred_idx=0)

    from pprint import pprint
    import numpy as np
    pprint(optimizer_params)
    np.set_printoptions(formatter={"float": lambda x: "%.3f " % x},
                        suppress=True)

    pprint({"ctx": ctx, "begin_epoch": begin_epoch, "end_epoch": end_epoch,
            "learning_rate": learning_rate, "momentum": momentum})

    # The three solvers take the same arguments, so pick the class by the
    # network-name suffix and build it once.
    if netname.endswith("multi"):
        solver_cls = MultiTaskSolver
    elif netname.endswith("det"):
        solver_cls = DetTaskSolver
    elif netname.endswith("seg"):
        solver_cls = SegTaskSolver
    else:
        raise NotImplementedError(
            "no solver for network '{}'".format(netname))

    model = solver_cls(
        ctx=ctx,
        symbol=net,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        arg_params=args,
        aux_params=auxs,
        learning_rate=learning_rate,
        lr_scheduler=lr_scheduler,
        momentum=momentum,
        # Use the caller's weight decay; it used to be pinned to 0.0005 here
        # regardless of the `weight_decay` argument.
        wd=weight_decay,
        valid_metric=valid_metric,
        class_names=class_names,
    )
    model.fit(train_data=train_iter,
              eval_data=val_iter,
              batch_end_callback=batch_end_callback,
              epoch_end_callback=epoch_end_callback)
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="",
              iter_monitor=0, monitor_pattern=".*", log_file=None):
    """
    Build the data iterators, the training symbol and an ``mx.mod.Module``,
    then run ``Module.fit``.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : number or tuple of floats
        mean pixel values for red, green and blue; a single number is
        repeated for all three channels
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of the pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # logging: root logger at INFO, optionally mirrored into a file
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        logger.addHandler(logging.FileHandler(log_file))

    # normalise the shape / mean arguments
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    side = data_shape[1]
    if prefix.endswith('_'):
        prefix += '_' + str(side)

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels] * 3
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # data iterators; validation is optional
    train_iter = DetRecordIter(train_path, batch_size, data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)
    val_iter = None
    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)

    # training symbol
    net = get_symbol_train(net, side, num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress, nms_topk=nms_topk)

    # parameters frozen by the user-supplied pattern
    fixed_param_names = None
    if freeze_layer_pattern.strip():
        matcher = re.compile(freeze_layer_pattern)
        fixed_param_names = [arg for arg in net.list_arguments()
                             if matcher.match(arg)]

    # initial weights: resume > finetune > pretrained > scratch
    ctx_str = '(' + ','.join(str(device) for device in ctx) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
                    .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
                    .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # Bind once on CPU to learn the shapes the new symbol expects, then
        # drop loaded weights whose shape differs and freeze the matching
        # ones, except those with 'pred' in their name.
        executor = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300),
                                   label=(1, 1, 5), grad_req='null')
        fixed_param_names = []
        for name, bound in executor.arg_dict.items():
            if name not in args:
                continue
            if bound.shape != args[name].shape:
                del args[name]
                logging.info("Removed %s" % name)
            elif 'pred' not in name:
                fixed_param_names.append(name)
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
                    .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
                    .format(ctx_str))
        args, auxs, fixed_param_names = None, None, None

    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # training module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger,
                        context=ctx, fixed_param_names=fixed_param_names)

    # callbacks and optimizer settings
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(
        learning_rate, lr_refactor_step, lr_refactor_ratio,
        num_example, batch_size, begin_epoch)
    num_devices = len(ctx)
    if num_devices > 0:
        grad_scale = 1.0 / num_devices
    else:
        grad_scale = 1.0
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        'rescale_grad': grad_scale,
    }
    monitor = None
    if iter_monitor > 0:
        monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern)

    # validation metric: VOC07 11-point or regular mAP
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    valid_metric = metric_cls(ovp_thresh, use_difficult, class_names,
                              pred_idx=3)

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape,
                 model_prefix, epoch, ctx=mx.cpu(), batch_size=1,
                 path_imglist="", nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False, class_names=None,
                 voc07_metric=False):
    """
    Score a saved checkpoint on a validation record file and print the
    metric values.

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    path_imglist : str
        path to the list file to replace labels in record file, optional
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint; '_<height>' is appended before loading
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    """
    # logging at INFO on the root logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # normalise the shape argument and derive the checkpoint prefix
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    side = data_shape[1]
    model_prefix += '_' + str(side)

    # validation iterator
    eval_iter = DetRecordIter(path_imgrec, batch_size, data_shape,
                              mean_pixels=mean_pixels,
                              path_imglist=path_imglist, **cfg.valid)

    # checkpoint: symbol + parameters
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)

    # use the stored symbol unless a network name was given
    if net is None:
        net = load_net
    else:
        net = get_symbol(net, side, num_classes=num_classes,
                         nms_thresh=nms_thresh, force_suppress=force_nms)
    # the metric needs labels, so make sure the symbol exposes a 'label' input
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    # module with every argument marked as fixed
    mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                        context=ctx, fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # score and report
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    metric = metric_cls(ovp_thresh, use_difficult, class_names)
    results = mod.score(eval_iter, metric, num_batch=None)
    for name, value in results:
        print("{}: {}".format(name, value))
def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape,
                 model_prefix, epoch, ctx=mx.cpu(), batch_size=1,
                 path_imglist="", nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False, class_names=None,
                 voc07_metric=False, frequent=20):
    """
    Evaluate a network given a validation record file.

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    path_imglist : str
        path to the list file to replace labels in record file, optional
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
        NOTE(review): accepted but not forwarded to the iterator in this
        version - confirm that this is intended.
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint; also decides where the 'roc'
        output path handed to the metric is located
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    frequent : int
        frequency to print out validation status
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3

    # iterator
    eval_iter = DetRecordIter(path_imgrec, batch_size, data_shape,
                              path_imglist=path_imglist, **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if net is None:
        net = load_net
    else:
        net = get_symbol(net, data_shape[1], num_classes=num_classes,
                         nms_thresh=nms_thresh, force_suppress=force_nms)
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    # Render the network graph. The shapes are fixed placeholders for the
    # plot only, kept in their own local so the `data_shape` argument is not
    # overwritten.
    plot_data_shape = (1, 3, 300, 300)
    mx.viz.plot_network_detail(net,
                               shape={
                                   "data": plot_data_shape,
                                   "label": (1, 1, 5)
                               },
                               node_attrs={
                                   "hide_weights": "true",
                                   "fixedsize": 'false',
                                   "shape": 'oval'
                               }).view()

    # Drop checkpoint parameters the symbol does not declare. The keys are
    # snapshotted first: deleting from a dict while iterating over its live
    # view raises RuntimeError on Python 3.
    exe = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300), label=(1, 1, 5),
                          grad_req='null')
    arg_dict = exe.arg_dict
    for k in list(args):
        if k not in arg_dict:
            del args[k]

    # init module
    mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                        context=ctx, fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=True, force_init=True)

    # run evaluation
    roc_output_path = os.path.join(os.path.dirname(model_prefix), 'roc')
    if voc07_metric:
        metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                roc_output_path=roc_output_path)
    else:
        metric = MApMetric(ovp_thresh, use_difficult, class_names,
                           roc_output_path=roc_output_path)
    results = mod.score(eval_iter, metric, num_batch=None,
                        batch_end_callback=mx.callback.Speedometer(
                            batch_size, frequent=frequent, auto_reset=False))
    for k, v in results:
        print("{}: {}".format(k, v))
def train_net_common(net, train_iter, val_iter, batch_size,
                     data_shape, resume, finetune, pretrained, epoch,
                     prefix, ctx, begin_epoch, end_epoch, frequent,
                     learning_rate, momentum, weight_decay, use_plateau,
                     lr_refactor_step, lr_refactor_ratio,
                     freeze_layer_pattern='',
                     num_example=10000, label_pad_width=350,
                     nms_thresh=0.45, force_suppress=False, ovp_thresh=0.5,
                     use_difficult=False, class_names=None,
                     optimizer_name='sgd',
                     voc07_metric=False, nms_topk=400,
                     iter_monitor=0, monitor_pattern=".*", logger=None):
    """
    Shared training routine: build the symbol, load initial weights, set up
    the module and optimizer, then fit.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_iter : data iterator
        training data; must expose provide_data, provide_label, batch_size
    val_iter : data iterator or None
        validation data handed to fit as eval_data
    batch_size : int
        training batch-size, forwarded to the learning-rate scheduler helper
    data_shape : tuple
        indexable shape; element 1 is used as the network input size
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0; freezes every argument
        whose name starts with 'conv'
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of the pretrained model
    prefix : str
        prefix for saving checkpoints; '_<net>_<size>' is appended
    ctx : list of mxnet contexts
    begin_epoch : int
        starting epoch for training
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum, only applied when optimizer_name is 'sgd'
    weight_decay : float
        training weight decay param
    use_plateau : boolean
        train with PlateauModule / PlateauScheduler instead of the
        step-based learning-rate scheduler
    lr_refactor_step : comma separated integers or number
        step epochs for the step scheduler; passed as patient_epochs to the
        plateau scheduler when use_plateau is set
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        accepted for interface compatibility; not used by this function
    nms_thresh : float
        non-maximum suppression threshold, forwarded to the symbol
    force_suppress : boolean
        forwarded to the symbol
    ovp_thresh : float
        overlap threshold handed to the validation metric
    use_difficult : boolean
        handed to the validation metric
    class_names : list
        class names; its length is used as the number of classes, so a
        sized sequence is required despite the None default
    optimizer_name : str
        optimizer passed to init_optimizer and fit
    voc07_metric : boolean
        use VOC07MApMetric instead of MApMetric for validation
    nms_topk : int
        forwarded to the symbol
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    logger : logging.Logger or None
        logger to report to; the root logger is used when None
    """
    if logger is None:
        # The default would otherwise fail on the first logger.info call.
        import logging
        logger = logging.getLogger()

    # check args
    prefix += '_' + net + '_' + str(data_shape[1])

    # load symbol
    net_str = net
    net = get_symbol_train(net, data_shape[1], num_classes=len(class_names),
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
                    .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
                    .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        try:
            logger.info("Start training with {} from pretrained model {}"
                        .format(ctx_str, pretrained))
            _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
            args = convert_pretrained(pretrained, args)
            if net_str == 'ssd_pva':
                args, auxs = convert_pvanet(args, auxs)
        except Exception:
            # Best-effort fallback is kept, but KeyboardInterrupt/SystemExit
            # now propagate and the cause of the failure is logged.
            logger.info(
                "Failed to load the pretrained model. Start from scratch.",
                exc_info=True)
            args = None
            auxs = None
            fixed_param_names = None
    else:
        logger.info("Experimental: start training from scratch with {}"
                    .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    if not use_plateau:  # focal loss does not go well with plateau
        mod = mx.mod.Module(net, label_names=('label',), logger=logger,
                            context=ctx, fixed_param_names=fixed_param_names)
    else:
        mod = PlateauModule(net, label_names=('label',), logger=logger,
                            context=ctx, fixed_param_names=fixed_param_names)

    # robust parameter setting
    mod.bind(data_shapes=train_iter.provide_data,
             label_shapes=train_iter.provide_label)
    mod = set_mod_params(mod, args, auxs, logger)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent,
                                                 auto_reset=True)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) \
        if iter_monitor > 0 else None

    # The step scheduler must exist before the optimizer is created.
    # Previously it was computed after init_optimizer and never handed to
    # it, so the learning rate was never rescaled in non-plateau mode.
    lr_scheduler = None
    if not use_plateau:
        learning_rate, lr_scheduler = get_lr_scheduler(
            learning_rate, lr_refactor_step, lr_refactor_ratio,
            num_example, batch_size, begin_epoch)

    optimizer_params = {'learning_rate': learning_rate,
                        'wd': weight_decay,
                        'clip_gradient': 4.0,
                        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    if optimizer_name == 'sgd':
        optimizer_params['momentum'] = momentum
    if lr_scheduler is not None:
        optimizer_params['lr_scheduler'] = lr_scheduler

    # #7847
    mod.init_optimizer(optimizer=optimizer_name,
                       optimizer_params=optimizer_params, force_init=True)

    if use_plateau:
        w_l1 = cfg.train['smoothl1_weight']
        eval_weights = {'CrossEntropy': 1.0, 'SmoothL1': w_l1}
        plateau_lr = PlateauScheduler(
            patient_epochs=lr_refactor_step,
            factor=float(lr_refactor_ratio),
            eval_weights=eval_weights)
        # NOTE(review): hard-coded absolute path to a stats file; this only
        # works on a machine that has this location.
        plateau_metric = MultiBoxMetric(
            fn_stat='/home/hyunjoon/github/additions_mxnet/ssd/stat.txt')

    eval_metric = MultiBoxMetric()
    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                      pred_idx=4)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names,
                                 pred_idx=4)

    if not use_plateau:
        mod.fit(train_iter,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
    else:
        mod.fit(train_iter,
                plateau_lr,
                plateau_metric=plateau_metric,
                fn_curr_model=prefix + '-1000.params',
                plateau_backtrace=False,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                validation_period=5,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_img, mean_img_dir, resume, finetune,
              pretrained, epoch, prefix, ctx, begin_epoch, end_epoch,
              frequent, learning_rate, momentum, weight_decay,
              lr_refactor_step, lr_refactor_ratio, convert_numpy=1,
              freeze_layer_pattern='', num_example=10000,
              label_pad_width=350, nms_thresh=0.45, force_nms=False,
              ovp_thresh=0.5, use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None, summarywriter=0,
              flush_secs=180):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_img :
        forwarded to DetRecordIter as `mean_img` and to
        `_convert_mean_numpy` (presumably the mean image file - confirm)
    mean_img_dir :
        forwarded to `_convert_mean_numpy` (presumably the directory that
        holds the mean image - confirm)
    resume : int
        resume from the checkpoint of this epoch if > 0
    finetune : int
        fine-tune from the checkpoint of this epoch if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        epoch of the pretrained model to load
    prefix : str
        prefix for saving checkpoints; '_<net>_<height>' is appended to it
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    convert_numpy : int
        forwarded to `_convert_mean_numpy` (presumably enables the
        mean.bin -> mean.npy conversion - confirm)
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        accepted for interface compatibility; not referenced in this function
    ovp_thresh : float
        overlap threshold handed to the validation mAP metric
    use_difficult : boolean
        handed to the validation mAP metric
    class_names : list of str or None
        handed to the validation mAP metric
    voc07_metric : boolean
        use VOC07MApMetric instead of MApMetric for validation
    nms_topk : int
        forwarded to `get_symbol_train`
    force_suppress : boolean
        forwarded to `get_symbol_train`
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    summarywriter : int
        if truthy, the log directory is wiped and a SummaryWriter records the
        network graph and the training metrics there
    flush_secs : int
        flush interval handed to SummaryWriter
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])

    train_iter = DetRecordIter(train_path, batch_size, data_shape,
                               mean_img=mean_img,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_img=mean_img,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # convert mean.bin to mean.npy
    _convert_mean_numpy(convert_numpy, mean_img_dir, mean_img)

    # load symbol
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress, nms_topk=nms_topk)

    if summarywriter:
        log_dir = '/opt/incubator-mxnet/example/ssd/logs'
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir)  # clear the previous logs
        # makedirs (not mkdir) so a missing parent directory is not fatal
        os.makedirs(log_dir)
        sw = SummaryWriter(logdir=log_dir, flush_secs=flush_secs)
        sw.add_graph(net)
    else:
        sw = None

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                        context=ctx, fixed_param_names=fixed_param_names)

    # fit parameters
    if summarywriter:
        # With several batch-end callbacks, only the last one may reset the
        # metric. The Speedometer runs before the summary writer, so it must
        # not clear the values the writer is about to record.
        batch_end_callbacks = [
            mx.callback.Speedometer(train_iter.batch_size,
                                    frequent=frequent, auto_reset=False),
            summary_writter_callback.summary_writter_eval_metric(sw)
        ]
    else:
        batch_end_callbacks = [
            mx.callback.Speedometer(train_iter.batch_size,
                                    frequent=frequent, auto_reset=False)
        ]
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(
        learning_rate, lr_refactor_step, lr_refactor_ratio,
        num_example, batch_size, begin_epoch)
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    }
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult,
                                      class_names, pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult,
                                 class_names, pred_idx=3)

    try:
        mod.fit(train_iter,
                val_iter,
                eval_metric=MultiBoxMetric(),
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callbacks,
                epoch_end_callback=epoch_end_callback,
                optimizer='sgd',
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
    finally:
        # flush and release the writer even when training aborts
        if sw is not None:
            sw.close()
def evaluate(netname, path_imgrec, num_classes, num_seg_classes,
             mean_pixels, data_shape, model_prefix, epoch, ctx=mx.cpu(),
             batch_size=1, path_imglist="", nms_thresh=0.45,
             force_nms=False, ovp_thresh=0.5, use_difficult=False,
             class_names=None, seg_class_names=None, voc07_metric=False):
    """
    evaluate network given validation record file

    Parameters:
    ----------
    netname : str or None
        Network name ending in "det", "seg" or "multi", or None to use the
        symbol stored with the checkpoint
    path_imgrec : str
        path to the record validation file
    num_classes : int
        number of classes, not including background
    num_seg_classes : int
        not referenced in this function
    mean_pixels : tuple
        (mean_r, mean_g, mean_b); not referenced in this function
    data_shape : int, str or sequence
        height/width as int, a "3,height,width" string, or a
        (3, height, width) sequence
    model_prefix : str
        model prefix of saved checkpoint; '_<height>' is appended to it
    epoch : int
        load model epoch
    ctx : mx.ctx or list of mx.ctx
        mx.gpu() or mx.cpu(); when a list is given its first entry is used
    batch_size : int
        validation batch size
    path_imglist : str
        path to the list file to replace labels in record file, optional
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    seg_class_names :
        class names handed to the segmentation IoU metric
    voc07_metric : boolean
        not referenced in this function
    """
    global outimgiter

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    elif hasattr(data_shape, "split"):
        # materialise the values: on Python 3 `map` is a lazy iterator and
        # cannot be passed to len() or indexed
        data_shape = tuple(int(dim) for dim in data_shape.split(","))
    else:
        data_shape = tuple(int(dim) for dim in data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    model_prefix += '_' + str(data_shape[1])

    # iterator
    eval_iter = MultiTaskRecordIter(path_imgrec, batch_size, data_shape,
                                    path_imglist=path_imglist,
                                    enable_aug=False, **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)

    # network
    # NOTE(review): `net` is built and validated here but the executor below
    # binds the checkpoint symbol (`load_net`) - confirm this is intended.
    if netname is None:
        net = load_net
    elif netname.endswith("det"):
        net = get_det_symbol(netname.split("_")[0], data_shape[1],
                             num_classes=num_classes, nms_thresh=nms_thresh,
                             force_suppress=force_nms)
    elif netname.endswith("seg"):
        net = get_seg_symbol(netname.split("_")[0], data_shape[1],
                             num_classes=num_classes, nms_thresh=nms_thresh,
                             force_suppress=force_nms)
    elif netname.endswith("multi"):
        net = get_multi_symbol(netname.split("_")[0], data_shape[1],
                               num_classes=num_classes, nms_thresh=nms_thresh,
                               force_suppress=force_nms)
    else:
        raise NotImplementedError("")

    if 'label_det' not in net.list_arguments():
        label_det = mx.sym.Variable(name='label_det')
        net = mx.sym.Group([net, label_det])
    if 'seg_out_label' not in net.list_arguments():
        seg_out_label = mx.sym.Variable(name='seg_out_label')
        net = mx.sym.Group([net, seg_out_label])

    # Accept both a single context (the declared default) and a list of them.
    if isinstance(ctx, (list, tuple)):
        ctx = ctx[0]

    eval_metric = CustomAccuracyMetric()
    multibox_metric = MultiBoxMetric()
    depth_metric = DistanceAccuracyMetric(class_names=class_names)
    det_metric = MApMetric(ovp_thresh, use_difficult, class_names)
    seg_metric = IoUMetric(class_names=seg_class_names, axis=1)
    eval_metrics = metric.CompositeEvalMetric()
    eval_metrics.add(multibox_metric)
    eval_metrics.add(eval_metric)

    arg_params = {key: val.as_in_context(ctx) for key, val in args.items()}
    aux_params = {key: val.as_in_context(ctx) for key, val in auxs.items()}
    data_name = eval_iter.provide_data[0][0]
    label_name_det = eval_iter.provide_label[0][0]
    label_name_seg = eval_iter.provide_label[1][0]
    symbol = load_net
    output_names = symbol.list_outputs()

    # lookup table applied to the upsampled segmentation before it is saved;
    # the first 19 entries are remapped, everything else maps to 0
    lut = np.zeros(256)
    lut[:19] = np.array([
        7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32,
        33
    ])

    ############## monitor status
    def stat_helper(name, array):
        """wrapper for executor callback"""
        import ctypes
        from mxnet.ndarray import NDArray
        from mxnet.base import NDArrayHandle
        array = ctypes.cast(array, NDArrayHandle)
        if 1:
            array = NDArray(array, writable=False).asnumpy()
            print(name, array.shape, np.mean(array), np.std(array),
                  ('%.1fms' %
                   (float(time.time() - stat_helper.start_time) * 1000)))
        else:
            array = NDArray(array, writable=False)
            array.wait_to_read()
            elapsed = float(time.time() - stat_helper.start_time) * 1000.
            if elapsed > 5:
                print(name, array.shape, ('%.1fms' % (elapsed, )))
        stat_helper.start_time = time.time()

    # evaluation
    logger.info(" in eval process...")
    logger.info(
        str({
            "ovp_thresh": ovp_thresh,
            "nms_thresh": nms_thresh,
            "batch_size": batch_size,
            "force_nms": force_nms,
        }))
    nbatch = 0
    eval_iter.reset()
    eval_metrics.reset()
    det_metric.reset()
    total_time = 0

    for data, fnames in eval_iter:
        nbatch += 1
        label_shape_det = data.label[0].shape
        label_shape_seg = data.label[1].shape
        arg_params[data_name] = mx.nd.array(data.data[0], ctx)
        arg_params[label_name_det] = mx.nd.array(data.label[0], ctx)
        arg_params[label_name_seg] = mx.nd.array(data.label[1], ctx)
        executor = symbol.bind(ctx, arg_params, aux_states=aux_params)

        output_dict = dict(zip(output_names, executor.outputs))
        cpu_output_array = mx.nd.zeros(output_dict["seg_out_output"].shape)

        stat_helper.start_time = float(time.time())
        # executor.set_monitor_callback(stat_helper)

        ############## forward
        tic = time.time()
        # NOTE(review): forward runs with is_train=True during evaluation -
        # confirm this is required by the loss outputs read below.
        executor.forward(is_train=True)
        output_dict["seg_out_output"].copyto(cpu_output_array)
        output_dict["seg_out_output"].wait_to_read()
        toc = time.time()

        pred_seg_shape = output_dict["seg_out_output"].shape
        label_det = mx.nd.array(data.label[0].reshape(
            (label_shape_det[0], label_shape_det[1] * label_shape_det[2])))
        label_seg = mx.nd.array(data.label[1].reshape(
            (label_shape_seg[0], label_shape_seg[1] * label_shape_seg[2])),
            ctx=ctx)
        pred_seg = mx.nd.array(output_dict["seg_out_output"].reshape(
            (pred_seg_shape[0], pred_seg_shape[1],
             pred_seg_shape[2] * pred_seg_shape[3])), ctx=ctx)

        #### remove invalid boxes
        out_det = output_dict["det_out_output"].asnumpy()
        indices = np.where(out_det[:, :, 0] >= 0)  # labeled as negative
        out_det = np.expand_dims(out_det[indices[0], indices[1], :], axis=0)
        indices = np.where(out_det[:, :, 1] > .1)  # higher confidence
        out_det = np.expand_dims(out_det[indices[0], indices[1], :], axis=0)
        pred_det = mx.nd.array(out_det)

        ################# display results ####################
        out_img = output_dict["seg_out_output"]
        out_img = mx.nd.split(out_img, axis=0, num_outputs=out_img.shape[0])
        for imgidx in range(batch_size):
            seg_prob = out_img[imgidx]
            res_img = np.squeeze(
                seg_prob.asnumpy().argmax(axis=0).astype(np.uint8))
            label_img = data.label[1].asnumpy()[imgidx, :, :].astype(np.uint8)
            img = np.squeeze(data.data[0].asnumpy()[imgidx, :, :, :])
            det = out_det[imgidx, :, :]
            gt = label_det.asnumpy()[imgidx, :].reshape((-1, 6))

            # save to results folder for evaluation
            res_fname = fnames[imgidx].replace("SegmentationClass", "results")
            seg_resized = prob_upsampling(seg_prob, target_shape=(1024, 2048))
            seg_resized2 = cv2.LUT(seg_resized, lut)
            cv2.imwrite(res_fname, seg_resized2)

            # display result
            print(fnames[imgidx], np.average(img))
            display_img = display_results(res_img,
                                          np.expand_dims(label_img, axis=0),
                                          img, det, gt, class_names)
            res_fname = fnames[imgidx].replace(
                "SegmentationClass", "output").replace(
                    "labelTrainIds", "output")
            cv2.imwrite(res_fname, display_img)
            # ESC aborts the whole evaluation run
            if (cv2.waitKey() & 0xff) == 27:
                raise SystemExit(0)
            outimgiter += 1
        ################# display results ####################

        eval_metrics.get_metric(0).update(None, [
            output_dict["cls_prob_output"], output_dict["loc_loss_output"],
            output_dict["cls_label_output"]
        ])
        eval_metrics.get_metric(1).update([label_seg], [pred_seg])
        det_metric.update(
            [mx.nd.slice_axis(data.label[0], axis=2, begin=0, end=5)],
            [mx.nd.slice_axis(pred_det, axis=2, begin=0, end=6)])
        seg_metric.update([label_seg], [pred_seg])

        disparities = []
        for imgidx in range(batch_size):
            dispname = fnames[imgidx].replace(
                "SegmentationClass", "Disparity").replace(
                    "gtFine_labelTrainIds", "disparity")
            print(dispname)
            disparities.append(cv2.imread(dispname, -1))
        depth_metric.update(mx.nd.array(disparities), [pred_det])

        det_names, det_values = det_metric.get()
        seg_names, seg_values = seg_metric.get()
        depth_names, depth_values = depth_metric.get()
        total_time += toc - tic
        print("\r %d/%d %.1f%% speed=%.1fms %s=%.1f %s=%.1f %s=%.1f" % (
            nbatch * eval_iter.batch_size,
            eval_iter.num_samples,
            float(nbatch * eval_iter.batch_size) * 100. /
            float(eval_iter.num_samples),
            total_time * 1000. / nbatch,
            det_names[-1], det_values[-1] * 100.,
            seg_names[-1], seg_values[-1] * 100.,
            depth_names[-1], depth_values[-1] * 100.,
        ), end='\r')
        # if nbatch>50: break ## debugging

    separator = '----------------------------------------------'

    names, values = eval_metrics.get()
    for name, value in zip(names, values):
        logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)

    # detection, depth and segmentation are additionally logged as an
    # '&'-separated row of names and percentages
    for table_metric in (det_metric, depth_metric, seg_metric):
        logger.info(separator)
        names, values = table_metric.get()
        for name, value in zip(names, values):
            logger.info(' epoch[%d] Validation-%s=%f', epoch, name, value)
        logger.info(separator)
        logger.info(' & '.join(names))
        logger.info(' & '.join('%.1f' % (v * 100., ) for v in values))
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, use_plateau, lr_refactor_step,
              lr_refactor_ratio, use_global_stats=0,
              freeze_layer_pattern='', num_example=10000,
              label_pad_width=350, nms_thresh=0.45, force_nms=False,
              ovp_thresh=0.5, use_difficult=False, class_names=None,
              ignore_names=None, optimizer_name='sgd', voc07_metric=False,
              nms_topk=400, force_suppress=False, train_list="",
              val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from the checkpoint of this epoch if > 0
    finetune : int
        fine-tune from the checkpoint of this epoch if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        epoch of the pretrained model to load
    prefix : str
        prefix for saving checkpoints; '_<net>_<height>' is appended to it
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum, only applied when optimizer_name is 'sgd'
    weight_decay : float
        training weight decay param
    use_plateau : boolean
        train with PlateauModule / PlateauScheduler instead of the stock
        Module with a step learning-rate schedule
    lr_refactor_step : comma separated integers, or patience
        without plateau: at which epoch to rescale learning rate,
        e.g. '30, 60, 90'; with plateau: passed as `patient_epochs`
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    use_global_stats : int
        forwarded to `get_symbol_train`
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        accepted for interface compatibility; not referenced in this function
    ovp_thresh : float
        overlap threshold handed to the validation metrics
    use_difficult : boolean
        handed to the validation metrics
    class_names : list of str or None
        handed to the validation mAP metric
    ignore_names :
        forwarded to `get_symbol_train`
    optimizer_name : str
        name of the optimizer to create
    voc07_metric : boolean
        validate with VOC07MApMetric plus RecallMetric instead of MApMetric
    nms_topk : int
        forwarded to `get_symbol_train`
    force_suppress : boolean
        forwarded to `get_symbol_train`
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path, batch_size, data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net_str = net
    net = get_symbol_train(net, data_shape[1],
                           use_global_stats=use_global_stats,
                           num_classes=num_classes,
                           ignore_names=ignore_names,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress,
                           nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        try:
            logger.info(
                "Start training with {} from pretrained model {}".format(
                    ctx_str, pretrained))
            _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
            args = convert_pretrained(pretrained, args)
            if net_str == 'ssd_pva':
                args, auxs = convert_pvanet(args, auxs)
        except Exception:
            # Best effort: fall back to training from scratch, but keep
            # Ctrl-C/SystemExit working and record why loading failed.
            logger.info(
                "Failed to load the pretrained model. Start from scratch.",
                exc_info=True)
            args = None
            auxs = None
            fixed_param_names = None
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    if not use_plateau:  # focal loss does not go well with plateau
        mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                            context=ctx, fixed_param_names=fixed_param_names)
    else:
        mod = PlateauModule(net, label_names=('label', ), logger=logger,
                            context=ctx, fixed_param_names=fixed_param_names)

    # robust parameter setting
    mod.bind(data_shapes=train_iter.provide_data,
             label_shapes=train_iter.provide_label)
    mod = set_mod_params(mod, args, auxs, logger)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent,
                                                 auto_reset=True)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None
    optimizer_params = {
        'learning_rate': learning_rate,
        'wd': weight_decay,
        'clip_gradient': 4.0,
        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    }
    if optimizer_name == 'sgd':
        optimizer_params['momentum'] = momentum

    if not use_plateau:
        # The step schedule has to be part of the parameters *before* the
        # optimizer is force-initialised below: once it is initialised,
        # later optimizer settings passed to fit() are ignored, so a
        # scheduler created afterwards would never take effect.
        learning_rate, lr_scheduler = get_lr_scheduler(
            learning_rate, lr_refactor_step, lr_refactor_ratio,
            num_example, batch_size, begin_epoch)
        optimizer_params['learning_rate'] = learning_rate
        optimizer_params['lr_scheduler'] = lr_scheduler

    # #7847
    mod.init_optimizer(optimizer=optimizer_name,
                       optimizer_params=optimizer_params, force_init=True)

    if use_plateau:
        w_l1 = cfg.train['smoothl1_weight']
        eval_weights = {
            'CrossEntropy': 1.0,
            'SmoothL1': w_l1,
            'ObjectRecall': 0.0
        }
        plateau_lr = PlateauScheduler(
            patient_epochs=lr_refactor_step,
            factor=float(lr_refactor_ratio),
            eval_weights=eval_weights)
        # NOTE(review): machine-specific absolute path - confirm it should
        # not come from configuration instead.
        plateau_metric = MultiBoxMetric(
            fn_stat='/home/hyunjoon/github/additions_mxnet/ssd/stat.txt')
        mod.init_optimizer(optimizer=optimizer_name,
                           optimizer_params=optimizer_params)

    eval_metric = MultiBoxMetric()

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        map_metric = VOC07MApMetric(ovp_thresh, use_difficult,
                                    class_names, pred_idx=4)
        recall_metric = RecallMetric(ovp_thresh, use_difficult, pred_idx=4)
        valid_metric = mx.metric.create([map_metric, recall_metric])
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult,
                                 class_names, pred_idx=4)

    if not use_plateau:
        mod.fit(train_iter,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
    else:
        mod.fit(train_iter,
                plateau_lr,
                plateau_metric=plateau_metric,
                fn_curr_model=prefix + '-1000.params',
                plateau_backtrace=False,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                validation_period=5,
                kvstore='local',
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape,
                 model_prefix, epoch, ctx=mx.cpu(), batch_size=1,
                 path_imglist="", nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False, class_names=None,
                 voc07_metric=False, use_second_network=False, net1=None,
                 path_imgrec1=None, epoch1=None, model_prefix1=None,
                 data_shape1=None):
    """
    Evaluate a two-output detection network on the Caltech validation set,
    or (optionally) score a second network read from a record file.

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        not referenced in this function (the Caltech imdb is used instead)
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        not referenced in this function ([128, 128, 128] is used)
    data_shape : tuple, list or int
        (3, height, width), [height, width] or height/width
    model_prefix : str
        model prefix of saved checkpoint
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    path_imglist : str
        list file handed to the second network's record iterator
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    use_second_network : boolean
        when true, skip scoring the first network and run `score_m` on the
        second one instead
    net1, path_imgrec1, epoch1, model_prefix1, data_shape1 :
        counterparts of the first-network arguments for the second network
    """

    def _as_chw(shape):
        # int -> square image, [h, w] list -> (3, h, w); anything else is
        # expected to already be (3, h, w)
        if isinstance(shape, int):
            shape = (3, shape, shape)
        elif isinstance(shape, list):
            shape = (3, shape[0], shape[1])
        assert len(shape) == 3 and shape[0] == 3
        return shape

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    data_shape = _as_chw(data_shape)

    # iterator over the Caltech validation split
    here = os.path.abspath(os.path.dirname(__file__))
    caltech_root = os.path.join(here, '..', 'data',
                                'caltech-pedestrian-dataset-converter')
    imdb_val = load_caltech(image_set='val', caltech_path=caltech_root,
                            shuffle=False)
    eval_iter = DetIter(imdb_val, batch_size, (data_shape[1], data_shape[2]),
                        mean_pixels=[128, 128, 128], rand_samplers=[],
                        rand_mirror=False, shuffle=False, rand_seed=None,
                        is_train=True, max_crop_trial=50)

    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)

    # network
    if net is None:
        net = load_net
    else:
        net = get_symbol_concat(net, data_shape[1], num_classes=num_classes,
                                nms_thresh=nms_thresh,
                                force_suppress=force_nms)
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        label2 = mx.sym.Variable(name='label2')
        net = mx.sym.Group([net, label, label2])

    # init module
    mod = mx.mod.Module(net, label_names=('label', 'label2'), logger=logger,
                        context=ctx, fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # both metric flavours share the same two-head configuration
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    metric = metric_cls(ovp_thresh, use_difficult, class_names,
                        pred_idx=[0, 1],
                        output_names=['detection_output',
                                      'detection2_output'],
                        label_names=['label', 'label2'])

    # run evaluation
    if not use_second_network:
        for key, score in mod.score(eval_iter, metric, num_batch=None):
            print("{}: {}".format(key, score))
        return

    logging.basicConfig()
    logger1 = logging.getLogger()
    logger1.setLevel(logging.INFO)

    # load sub network
    data_shape1 = _as_chw(data_shape1)

    # iterator
    eval_iter1 = DetRecordIter(path_imgrec1, batch_size, data_shape1,
                               path_imglist=path_imglist, **cfg.valid)
    # model params
    load_net1, args1, auxs1 = mx.model.load_checkpoint(model_prefix1, epoch1)

    # network
    # NOTE(review): a non-None `net1` is replaced by the first network's
    # symbol rather than being resolved itself - confirm this is intended.
    net1 = load_net1 if net1 is None else net
    if 'label' not in net1.list_arguments():
        label1 = mx.sym.Variable(name='label')
        net1 = mx.sym.Group([net1, label1])

    # init module
    mod1 = mx.mod.Module(net1, label_names=('label', ), logger=logger1,
                         context=ctx,
                         fixed_param_names=net1.list_arguments())
    mod1.bind(data_shapes=eval_iter1.provide_data,
              label_shapes=eval_iter1.provide_label)
    mod1.set_params(args1, auxs1, allow_missing=False, force_init=True)

    metric_cls1 = VOC07MApMetric if voc07_metric else MApMetric
    metric1 = metric_cls1(ovp_thresh, use_difficult, class_names)

    filepath1 = '/home/binghao/workspace/MXNet-SSD/matlab/kitti/outputs/ssd_small/'
    mod1.score_m(filepath1, eval_iter1, metric1, num_batch=None)
def evaluate_net(net, imdb, mean_pixels, data_shape,
                 model_prefix, epoch, ctx=mx.cpu(), batch_size=1,
                 nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False,
                 voc07_metric=False):
    """
    Evaluate a detection network one batch at a time over an image database
    and print the resulting metric values.

    Parameters:
    ----------
    net : str or None
        Network name handed to get_symbol, or None to use the symbol stored
        with the checkpoint without modifying it
    imdb : object
        image database; its `num_classes` and `classes` attributes are read
        and the object itself is given to FaceTestIter
    mean_pixels : tuple
        forwarded to FaceTestIter
    data_shape : tuple or int
        (3, height, width), or a single int used for both spatial sizes
    model_prefix : str
        prefix of the saved checkpoint; '_' + str(data_shape[1]) is appended
        before the checkpoint is loaded
    epoch : int
        checkpoint epoch to load
    ctx : mx.ctx
        context the module is created on
    batch_size : int
        accepted for interface compatibility; not used by this function
    nms_thresh : float
        threshold passed both to get_symbol and to do_nms
    force_nms : boolean
        forwarded to get_symbol as `force_suppress`
    ovp_thresh : float
        first argument of the metric constructor
    use_difficult : boolean
        second argument of the metric constructor
    voc07_metric : boolean
        choose VOC07MApMetric instead of MApMetric
    """
    # root logger at INFO level
    logging.basicConfig()
    log = logging.getLogger()
    log.setLevel(logging.INFO)

    num_classes = imdb.num_classes
    class_names = imdb.classes

    # normalise the shape argument and derive the checkpoint name from it
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    model_prefix = model_prefix + '_' + str(data_shape[1])

    batches = FaceTestIter(imdb, mean_pixels, img_stride=128, fix_hw=True)

    stored_symbol, arg_params, aux_params = mx.model.load_checkpoint(
        model_prefix, epoch)

    if net is not None:
        net = get_symbol(net, data_shape[1], num_classes=num_classes,
                         nms_thresh=nms_thresh, force_suppress=force_nms)
    else:
        net = stored_symbol
    if 'label' not in net.list_arguments():
        # the module below is built with a 'label' input, so add one
        net = mx.sym.Group([net, mx.sym.Variable(name='label')])

    # every argument is listed as fixed: nothing is trained here
    module = mx.mod.Module(net, label_names=('label',), logger=log,
                           context=ctx,
                           fixed_param_names=net.list_arguments())
    module.bind(data_shapes=batches.provide_data,
                label_shapes=batches.provide_label)
    module.set_params(arg_params, aux_params,
                      allow_missing=False, force_init=True)

    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    metric = metric_cls(ovp_thresh, use_difficult, class_names)

    for index, (batch, info) in enumerate(batches):
        # shapes are taken from each batch, so the module is re-shaped
        module.reshape(data_shapes=batch.provide_data,
                       label_shapes=batch.provide_label)
        module.forward(batch)

        outputs = module.get_outputs()
        # first sample of the first output, (n_anchor, 6)
        kept = do_nms(outputs[0][0].asnumpy(), 1, nms_thresh)
        outputs[0][0] = mx.nd.array(kept, ctx=outputs[0].context)

        # scale label columns 1..4 by the image extent
        # (presumably normalised box coordinates -> pixels; TODO confirm)
        scale_y, scale_x, _ = info['im_shape']
        factors = mx.nd.array((1.0, scale_x, scale_y, scale_x, scale_y, 1.0))
        batch.label[0] *= mx.nd.reshape(factors, (1, 1, -1))
        metric.update(batch.label, outputs)

        if index % 10 == 0:
            print('processed {} images.'.format(index))

    for name, value in metric.get_name_value():
        print("{}: {}".format(name, value))
def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape,
                 model_prefix, epoch, ctx=mx.cpu(), batch_size=1,
                 path_imglist="", nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False, class_names=None,
                 voc07_metric=False, frequent=20,
                 linemod_path='/data/ZHANGXIN/DATASETS/SIXD_CHALLENGE/LINEMOD/'):
    """
    Evaluate a network on a validation record file with a detection metric
    and a pose metric, then save the results next to the checkpoint.

    Three files are written into os.path.dirname(model_prefix):
    'evaluate_results' (text, one "name: value" line per metric),
    'reprojection_error' and 'gt_count' (pickles of posemetric.Reproj and
    posemetric.counts, protocol 2).

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint (used as given, no suffix is added)
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    path_imglist : str
        path to the list file to replace labels in record file, optional
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false postives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evluation as in VOC07 competition
    frequent : int
        frequency to print out validation status
    linemod_path : str
        passed to PoseMetric as `LINEMOD_path`; the default keeps the
        location that used to be hard-coded in this function
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3

    # every artefact of this run lands beside the checkpoint
    output_dir = os.path.dirname(model_prefix)

    # iterator
    eval_iter = DetRecordIter(path_imgrec, batch_size, data_shape,
                              mean_pixels=mean_pixels, label_pad_width=350,
                              path_imglist=path_imglist, **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if net is None:
        net = load_net
    else:
        net = get_symbol(net, data_shape[1], num_classes=num_classes,
                         nms_thresh=nms_thresh, force_suppress=force_nms)
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    # init module; all arguments are fixed because nothing is trained
    mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                        context=ctx, fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # run evaluation
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    metric = metric_cls(ovp_thresh, use_difficult, class_names,
                        roc_output_path=os.path.join(output_dir, 'roc'))
    posemetric = PoseMetric(LINEMOD_path=linemod_path, classes=class_names)

    # NOTE: a disabled (commented-out) BB8 detection visualisation loop used
    # to sit here; it was never executed and has been dropped.

    # quantitive results
    results = mod.score(eval_iter, [metric, posemetric], num_batch=None,
                        batch_end_callback=mx.callback.Speedometer(
                            batch_size, frequent=frequent, auto_reset=False))

    # the `with` blocks close the files; no explicit close() is needed
    results_save_path = os.path.join(output_dir, 'evaluate_results')
    with open(results_save_path, 'w') as f:
        for k, v in results:
            print("{}: {}".format(k, v))
            f.write("{}: {}\n".format(k, v))

    reproj_save_path = os.path.join(output_dir, 'reprojection_error')
    with open(reproj_save_path, 'wb') as f:
        pickle.dump(posemetric.Reproj, f, protocol=2)

    count_save_path = os.path.join(output_dir, 'gt_count')
    with open(count_save_path, 'wb') as f:
        pickle.dump(posemetric.counts, f, protocol=2)
def train_net(net, dataset, image_set, devkit_path, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              year='', val_image_set=None, val_year='',
              freeze_layer_pattern='', label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, voc07_metric=False, nms_topk=400,
              force_suppress=False, iter_monitor=0, monitor_pattern=".*",
              log_file=None, num_example=10000):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure; the module "symbol_" + net
        is imported from <cfg.ROOT_DIR>/symbol
    dataset : str
        only 'pascal_voc' is supported; anything else raises
        NotImplementedError
    image_set : str
        train, trainval...
    devkit_path : str
        root directory of dataset
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints; '_' + str(data_shape[1]) is appended
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        accepted for interface compatibility; not passed to the optimizer
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    year : str
        2007, 2012 or combinations splitted by comma
    val_image_set : str
        validation image set; loaded when given together with val_year, but
        validation itself is currently disabled (val_iter is None)
    val_year : str
        year of the validation set
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    label_pad_width : int
        accepted for interface compatibility; not used by this function
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        accepted for interface compatibility; not used by this function
    ovp_thresh : float
        first argument of the validation metric constructor
    use_difficult : boolean
        second argument of the validation metric constructor
    voc07_metric : boolean
        choose VOC07MApMetric instead of MApMetric
    nms_topk : int
        forwarded to get_symbol_train
    force_suppress : boolean
        forwarded to get_symbol_train
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    num_example : int
        number of training images, handed to get_lr_scheduler
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # load dataset
    if dataset != 'pascal_voc':
        raise NotImplementedError("Dataset " + dataset + " not supported")
    imdb = load_pascal(image_set, year, devkit_path,
                       cfg.train['init_shuffle'])
    if val_image_set and val_year:
        # still loaded, although not consumed while val_iter is disabled
        val_imdb = load_pascal(val_image_set, val_year, devkit_path, False)
    else:
        val_imdb = None

    rand_scaler = RandScaler(min_scale=cfg.train['min_aug_scale'],
                             max_scale=cfg.train['max_aug_scale'],
                             min_gt_scale=cfg.train['min_aug_gt_scale'],
                             max_trials=cfg.train['max_aug_trials'],
                             max_sample=cfg.train['max_aug_sample'],
                             patch_size=cfg.train['aug_patch_size'])
    # init data iterator
    train_iter = DetIter(imdb, batch_size, data_shape, mean_pixels,
                         rand_scaler, cfg.train['rand_mirror'],
                         cfg.train['epoch_shuffle'], cfg.train['seed'],
                         is_train=True)
    # TODO: enable val_iter (build a DetIter from val_imdb with cfg.valid)
    val_iter = None

    # load symbol
    sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
    symbol_module = importlib.import_module("symbol_" + net)
    net = symbol_module.get_symbol_train(imdb.num_classes,
                                         nms_thresh=nms_thresh,
                                         force_suppress=force_suppress,
                                         nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        # training from scratch overrides any freeze pattern
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                        context=ctx, fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    # no 'momentum' entry: the optimizer used below is adam
    optimizer_params = {
        'learning_rate': learning_rate,
        'wd': weight_decay,
        'clip_gradient': 10.0,
        'rescale_grad': 1.0,
        'lr_scheduler': lr_scheduler,
    }
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    valid_metric = metric_cls(ovp_thresh, use_difficult, imdb.classes,
                              pred_idx=3)

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='adam',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
def train_multitask(net, train_path, num_classes, batch_size,
                    data_shape, mean_pixels, resume, finetune, pretrained,
                    epoch, prefix, ctx, begin_epoch, end_epoch, frequent,
                    learning_rate, momentum, weight_decay, lr_refactor_step,
                    lr_refactor_ratio, freeze_layer_pattern='',
                    num_example=10000, label_pad_width=350,
                    nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
                    use_difficult=False, class_names=None,
                    voc07_metric=False, nms_topk=400, force_suppress=False,
                    train_list="", val_path="", val_list="", iter_monitor=0,
                    monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.

    Before training starts, the first training batch is displayed in an
    OpenCV window (image with boxes on top, colourised segmentation below),
    one sample per key press; pressing ESC exits the process.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints; '_' + net + '_' + str(data_shape[1])
        is appended
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        trainig momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        accepted for interface compatibility; not used by this function
    class_names : list of str or None
        names used in the preview captions and by the validation metric;
        when empty or None the numeric class id is shown instead
    train_list : str
        list file path for training, this will replace the embeded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embeded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    logger.info(str({"train_path":train_path,"batch_size":batch_size,"data_shape":data_shape}))
    train_iter = MultiTaskRecordIter(train_path, batch_size, data_shape,
                                     mean_pixels=mean_pixels,
                                     label_pad_width=label_pad_width,
                                     path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = MultiTaskRecordIter(val_path, batch_size, data_shape,
                                       mean_pixels=mean_pixels,
                                       label_pad_width=label_pad_width,
                                       path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
                    .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
                    .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
                    .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
                    .format(ctx_str))
        args = None
        auxs = None
        # training from scratch overrides any freeze pattern
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    logger.info("Creating Base Module ...")
    mod = mx.mod.Module(net, label_names=('label_det','seg_out_label',),
                        logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    optimizer_params = {'learning_rate': learning_rate,
                        'momentum': momentum,
                        'wd': weight_decay,
                        'lr_scheduler': lr_scheduler,
                        'clip_gradient': None,
                        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                      pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names,
                                 pred_idx=3)

    # debugging aids; imported here because only this function needs them
    from pprint import pprint
    import numpy as np
    import cv2
    from palette import index2color
    pprint(optimizer_params)
    np.set_printoptions(formatter={"float":lambda x:"%.3f "%x},suppress=True)

    # ----- visualize network -----
    print(net)

    # ----- visualize data: preview the first training batch -----
    data_batch = train_iter.next()
    pprint({"data":data_batch.data[0].shape,
            "label_det":data_batch.label[0].shape,
            "seg_out_label":data_batch.label[1].shape})
    data = data_batch.data[0].asnumpy()
    label = data_batch.label[0].asnumpy()
    segmt = data_batch.label[1].asnumpy()
    for ii in range(data.shape[0]):
        img = data[ii, :, :, :]
        seg = segmt[ii, :, :]
        # print() call form so the module also parses under Python 3
        print(label[ii, :5, :])
        # reorder axes (0, 1, 2) -> (1, 2, 0) for OpenCV
        img = np.squeeze(img)
        img = np.swapaxes(img, 0, 2)
        img = np.swapaxes(img, 0, 1)
        # NOTE(review): fixed mean values are added back here rather than
        # the `mean_pixels` argument -- confirm this is intended
        img = (img + np.array([123.68, 116.779, 103.939]).reshape((1,1,3))).astype(np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        rois = label[ii, :, :]
        hh, ww, ch = img.shape
        for lidx in range(rois.shape[0]):
            roi = rois[lidx, :]
            # rows with a negative first column are skipped
            if roi[0] >= 0:
                bbox = [int(roi[1]*ww), int(roi[2]*hh),
                        int(roi[3]*ww), int(roi[4]*hh)]
                cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
                              (0,0,128))
                cls_id = int(roi[0])
                # class_names defaults to None; fall back to the numeric id
                cls_name = class_names[cls_id] if class_names else str(cls_id)
                text = '%s %.0fm' % (cls_name, roi[5]*255.)
                putText(img, bbox, text)
        # image on top, colourised segmentation underneath
        disp = np.zeros((hh*2, ww, ch), np.uint8)
        disp[:hh, :, :] = img.astype(np.uint8)
        disp[hh:, :, :] = index2color(seg)
        cv2.imshow("img", disp)
        if cv2.waitKey()&0xff==27:
            exit(0)
    # the preview consumed one batch; rewind so the first epoch sees it too
    train_iter.reset()

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
def train_net(network, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='', num_example=10000,
              label_pad_width=350, nms_thresh=0.45, force_nms=False,
              ovp_thresh=0.5, use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None, optimizer='sgd',
              tensorboard=False, checkpoint_period=5, min_neg_samples=0):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    network : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune if > 0; parameters are loaded from `pretrained` at epoch
        `finetune`, and any whose shape differs from the new network are
        dropped
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of the pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    optimizer : str
        usage of different optimizers, other then default sgd
    learning_rate : float
        training learning rate
    momentum : float
        trainig momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images; replaced by the line count of the
        '.idx' file next to `train_path` when that file exists
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        accepted for interface compatibility; not used by this function
    train_list : str
        list file path for training, this will replace the embeded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embeded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled; the file is placed in the directory of `prefix`
    tensorboard : bool
        record logs into tensorboard; requires `log_file`
    min_neg_samples : int
        always have some negative examples, no matter how many positive there are.
        this is useful when training on images with no ground-truth.
    checkpoint_period : int
        a checkpoint will be saved every "checkpoint_period" epochs

    Raises:
    ------
    ValueError
        if `tensorboard` is enabled without a `log_file`
    """
    # the tensorboard batch callback is given the log file path, so fail
    # with a clear message instead of an unbound-variable error later on
    if tensorboard and not log_file:
        raise ValueError("tensorboard logging requires log_file to be set")

    # check actual number of train_images; only the extension is swapped so
    # that 'rec' elsewhere in the path is left alone
    idx_path = os.path.splitext(train_path)[0] + '.idx'
    if idx_path != train_path and os.path.exists(idx_path):
        with open(idx_path, 'r') as f:
            num_example = sum(1 for _ in f)

    # everything this run writes goes next to the checkpoints
    output_dir = os.path.dirname(prefix)

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        log_file_path = os.path.join(output_dir, log_file)
        log_dir = os.path.dirname(log_file_path)
        # an empty dirname means the current directory: nothing to create
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir)
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path, batch_size, data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(network, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress, nms_topk=nms_topk,
                           minimum_negative_samples=min_neg_samples)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(pretrained, finetune)
        begin_epoch = finetune
        # check what layers mismatch with the loaded parameters
        exe = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300),
                              label=(1, 1, 5), grad_req='null')
        # no layer is frozen while fine-tuning (this also overrides
        # freeze_layer_pattern)
        fixed_param_names = []
        for name, arr in exe.arg_dict.items():
            if name in args and arr.shape != args[name].shape:
                del args[name]
                logger.info("Removed %s", name)
    elif pretrained and not finetune:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # visualize net - both train and test
    net_visualization(net=net, network=network, data_shape=data_shape[2],
                      output_dir=output_dir, train=True)
    net_visualization(net=None, network=network, data_shape=data_shape[2],
                      output_dir=output_dir, train=False,
                      num_classes=num_classes)

    # init training module
    mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                        context=ctx, fixed_param_names=fixed_param_names)

    batch_end_callback = []
    eval_end_callback = []
    epoch_end_callback = [
        mx.callback.do_checkpoint(prefix, period=checkpoint_period)
    ]

    # add logging to tensorboard
    if tensorboard:
        tensorboard_dir = os.path.join(output_dir, 'logs')
        # create each sub-directory on its own so a pre-existing 'logs'
        # directory with missing children is completed as well
        for parts in (('train', 'scalar'), ('train', 'dist'),
                      ('val', 'roc'), ('val', 'scalar'), ('val', 'images')):
            sub_dir = os.path.join(tensorboard_dir, *parts)
            if not os.path.exists(sub_dir):
                os.makedirs(sub_dir)
        batch_end_callback.append(
            ParseLogCallback(
                dist_logging_dir=os.path.join(tensorboard_dir, 'train', 'dist'),
                scalar_logging_dir=os.path.join(tensorboard_dir, 'train', 'scalar'),
                logfile_path=log_file_path,
                batch_size=batch_size,
                iter_monitor=iter_monitor,
                frequent=frequent))
        eval_end_callback.append(
            mx.contrib.tensorboard.LogMetricsCallback(
                os.path.join(tensorboard_dir, 'val/scalar'), 'ssd'))
        eval_end_callback.append(
            LogROCCallback(
                logging_dir=os.path.join(tensorboard_dir, 'val/roc'),
                roc_path=os.path.join(output_dir, 'roc'),
                class_names=class_names))
        eval_end_callback.append(
            LogDetectionsCallback(
                logging_dir=os.path.join(tensorboard_dir, 'val/images'),
                images_path=os.path.join(output_dir, 'images'),
                class_names=class_names,
                batch_size=batch_size,
                mean_pixels=mean_pixels))

    # this callback should be the last in a serie of batch_callbacks
    # since it is resetting the metric evaluation every $frequent batches
    batch_end_callback.append(
        mx.callback.Speedometer(train_iter.batch_size, frequent=frequent))

    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    # add possibility for different optimizer
    opt, opt_params = get_optimizer_params(optimizer=optimizer,
                                           learning_rate=learning_rate,
                                           momentum=momentum,
                                           weight_decay=weight_decay,
                                           lr_scheduler=lr_scheduler,
                                           ctx=ctx, logger=logger)
    # TODO monitor the gradient flow as in 'https://github.com/dmlc/tensorboard/blob/master/docs/tutorial/understanding-vanish-gradient.ipynb'
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    valid_metric = metric_cls(ovp_thresh, use_difficult, class_names,
                              pred_idx=3,
                              roc_output_path=os.path.join(output_dir, 'roc'))

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            eval_end_callback=eval_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer=opt,
            optimizer_params=opt_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
def evaluate_net(net, path_imgrec, num_classes, num_batch, mean_pixels,
                 data_shape, model_prefix, epoch, ctx=mx.cpu(), batch_size=32,
                 path_imglist="", nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False, class_names=None,
                 voc07_metric=False, lite=False):
    """
    Score a saved checkpoint on a validation record file and report throughput.

    The checkpoint is loaded from `model_prefix` with the data height and
    width appended ('<model_prefix>_<height>_<width>'). The metric results
    are printed and the measured images-per-second figure is logged.

    Parameters:
    ----------
    net : str or None
        network name handed to get_symbol, or None to use the symbol stored
        in the checkpoint without modification
    path_imgrec : str
        path to the validation record file
    num_classes : int
        number of classes, not including background
    num_batch : int
        number of batches; only used as num_batch * batch_size to compute the
        reported image count and speed
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or a single height/width value
    model_prefix : str
        model prefix of the saved checkpoint (shape suffix is appended here)
    epoch : int
        checkpoint epoch to load
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    path_imglist : str
        list file replacing the labels stored in the record file, optional
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        forwarded to get_symbol as `force_suppress`
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        use VOC07MApMetric (11-point evaluation) instead of MApMetric
    lite : boolean
        forwarded unchanged to get_symbol
    """
    # root logger at INFO level
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # normalise the data shape to (3, height, width)
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    shape_tag = '_'.join(str(dim) for dim in data_shape[1:])
    model_prefix = model_prefix + '_' + shape_tag

    # validation data iterator
    eval_iter = DetRecordIter(path_imgrec, batch_size, data_shape,
                              mean_pixels=mean_pixels,
                              path_imglist=path_imglist, **cfg.valid)

    # checkpoint: symbol plus parameters
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)

    # pick the network: rebuild it by name, or reuse the stored symbol
    if net is not None:
        net = get_symbol(net, data_shape, num_classes=num_classes,
                         nms_thresh=nms_thresh, force_suppress=force_nms,
                         lite=lite)
    else:
        net = load_net

    # the module below is bound with a 'label' input, so make sure one exists
    if 'label' not in net.list_arguments():
        net = mx.sym.Group([net, mx.sym.Variable(name='label')])

    # inference module: every argument is marked as fixed
    mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                        context=ctx, fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # evaluation metric
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    metric = metric_cls(ovp_thresh, use_difficult, class_names)

    num = num_batch * batch_size

    # warm up with a few forward passes on random data before timing
    warmup_data = [
        mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx)
        for _, shape in mod.data_shapes
    ]
    warmup_batch = mx.io.DataBatch(warmup_data, [])  # empty label
    for _ in range(5):
        mod.forward(warmup_batch, is_train=False)
        for out in mod.get_outputs():
            out.wait_to_read()

    # NOTE(review): score() is called with num_batch=None while the speed is
    # derived from num_batch * batch_size -- confirm callers pass a num_batch
    # that matches the iterator length.
    tic = time.time()
    results = mod.score(eval_iter, metric, num_batch=None,
                        batch_end_callback=mx.callback.Speedometer(
                            batch_size, frequent=10, auto_reset=False))
    speed = num / (time.time() - tic)

    logger.info('Finished inference with %d images', num)
    logger.info('Finished with %f images per second', speed)

    for name, value in results:
        print("{}: {}".format(name, value))
def train_net(net, train_path, num_classes, batch_size, data_shape,
              mean_pixels, resume, finetune, pretrained, epoch, prefix, ctx,
              begin_epoch, end_epoch, frequent, learning_rate, momentum,
              weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='', shape_range=(320, 512),
              random_shape_step=0, random_shape_epoch=10, num_example=10000,
              label_pad_width=350, nms_thresh=0.45, force_nms=False,
              ovp_thresh=0.5, use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.

    The symbol is imported from the module 'symbol_<net>' found in
    '<cfg.ROOT_DIR>/symbol'. Checkpoints are written under
    '<prefix>_<net without a trailing "_yolo">_<data height>'. When
    `random_shape_step` > 0, training is split into rounds of
    `random_shape_epoch` epochs, each with a randomly chosen square data shape.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue; a single number is
        replicated to all three channels
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        epoch of the pretrained model to load
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    shape_range : tuple of (min, max)
        random data shape range
    random_shape_step : int
        step size for random data shape, defined by network, 0 to disable
    random_shape_epoch : int
        number of epochs before the next random shape is drawn
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        accepted for interface compatibility; not read by this function
    ovp_thresh : float
        overlap threshold passed to the validation metric
    use_difficult : boolean
        passed to the validation metric
    class_names : list or str
        class names passed to the validation metric
    voc07_metric : boolean
        use VOC07MApMetric instead of MApMetric for validation
    nms_topk : int
        forwarded to the symbol's get_symbol
    force_suppress : boolean
        forwarded to the symbol's get_symbol
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3

    # Drop only a trailing '_yolo' from the network name. str.strip('_yolo')
    # would remove any of the characters '_', 'y', 'o', 'l' from BOTH ends
    # (e.g. 'tiny_yolo' -> 'tin'), corrupting the checkpoint prefix.
    # Names that were previously mangled now map to a different prefix, so
    # checkpoints saved under the old mangled name must be renamed to resume.
    yolo_suffix = '_yolo'
    net_tag = net[:-len(yolo_suffix)] if net.endswith(yolo_suffix) else net
    prefix += '_' + net_tag + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # load symbol; avoid growing sys.path on repeated calls
    symbol_dir = os.path.join(cfg.ROOT_DIR, 'symbol')
    if symbol_dir not in sys.path:
        sys.path.append(symbol_dir)
    symbol_module = importlib.import_module("symbol_" + net)
    net = symbol_module.get_symbol(num_classes, nms_thresh=nms_thresh,
                                   force_suppress=force_suppress,
                                   nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    allow_missing = True
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
        allow_missing = False
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        fixed_param_names = [
            name for name in net.list_arguments() if name.startswith('conv')
        ]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # validation metric used by fit to report mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                      pred_idx=0)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names,
                                 pred_idx=0)

    # init training module
    mod = mx.mod.Module(net, label_names=('yolo_output_label', ),
                        logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    # split training into rounds; each round may use a different data shape
    random_shape_step = int(random_shape_step)
    if random_shape_step > 0:
        fit_begins = list(range(begin_epoch, end_epoch, random_shape_epoch))
        fit_ends = fit_begins[1:] + [end_epoch]
        assert (len(shape_range) == 2)
        data_shapes = [
            (3, x * random_shape_step, x * random_shape_step)
            for x in range(shape_range[0] // random_shape_step,
                           shape_range[1] // random_shape_step + 1)
        ]
        logger.info("Candidate random shapes:" + str(data_shapes))
    else:
        fit_begins = [begin_epoch]
        fit_ends = [end_epoch]
        data_shapes = [data_shape]

    for begin, end in zip(fit_begins, fit_ends):
        if len(data_shapes) == 1:
            data_shape = data_shapes[0]
        else:
            data_shape = data_shapes[random.randint(0, len(data_shapes) - 1)]
            logger.info("Setting random data shape: " + str(data_shape))

        # iterators are rebuilt every round because the data shape may change
        train_iter = DetRecordIter(train_path, batch_size, data_shape,
                                   mean_pixels=mean_pixels,
                                   label_pad_width=label_pad_width,
                                   path_imglist=train_list, **cfg.train)

        if val_path:
            val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                     mean_pixels=mean_pixels,
                                     label_pad_width=label_pad_width,
                                     path_imglist=val_list, **cfg.valid)
        else:
            val_iter = None

        # NOTE(review): the scheduler is rebuilt each round from `begin_epoch`
        # (not the round's `begin`) and `learning_rate` is rebound to the
        # returned value -- confirm this is intended with random shapes on.
        learning_rate, lr_scheduler = get_lr_scheduler(
            learning_rate, lr_refactor_step, lr_refactor_ratio, num_example,
            batch_size, begin_epoch)
        optimizer_params = {
            'learning_rate': learning_rate,
            'momentum': momentum,
            'wd': weight_decay,
            'lr_scheduler': lr_scheduler,
            'clip_gradient': 10,
            'rescale_grad': 1.0,
        }

        mod.fit(train_iter,
                val_iter,
                eval_metric=MultiBoxMetric(),
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer='sgd',
                optimizer_params=optimizer_params,
                begin_epoch=begin,
                num_epoch=end,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=allow_missing,
                monitor=monitor,
                force_rebind=True,
                force_init=True)

        # carry the trained parameters into the next round; from here on every
        # parameter must be present
        args, auxs = mod.get_params()
        allow_missing = False
def train_net(net, train_path, num_classes, batch_size, data_shape,
              mean_pixels, resume, finetune, pretrained, epoch, prefix, ctx,
              begin_epoch, end_epoch, frequent, learning_rate, momentum,
              weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='', num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None, voc07_metric=False,
              nms_topk=2000, force_suppress=False, train_list="", val_path="",
              val_list="", iter_monitor=0, monitor_pattern=".*",
              log_file=None):
    """
    Build the data iterators, symbol and module, then run the training fit.

    Checkpoints are saved under '<prefix>_<net>_<data height>'.

    Parameters:
    ----------
    net : str
        symbol name for the network structure, passed to get_symbol_train
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue; a single number is
        replicated to all three channels
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        epoch of the pretrained model to load
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        accepted for interface compatibility; not read by this function
    ovp_thresh : float
        overlap threshold passed to the validation metric
    use_difficult : boolean
        passed to the validation metric
    class_names : list or str
        class names passed to the validation metric
    voc07_metric : boolean
        use VOC07MApMetric instead of MApMetric for validation
    nms_topk : int
        forwarded to get_symbol_train
    force_suppress : boolean
        forwarded to get_symbol_train
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # root logger at INFO level, optionally mirrored to a file
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        logger.addHandler(logging.FileHandler(log_file))

    # normalise the data shape to (3, height, width)
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix = '_'.join((prefix, net, str(data_shape[1])))

    # a scalar mean is applied to all three channels
    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels] * 3
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # data iterators; validation is optional
    train_iter = DetRecordIter(train_path, batch_size, data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)
    val_iter = None
    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)

    # training symbol
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress, nms_topk=nms_topk)

    # parameters whose names match the freeze pattern stay fixed
    fixed_param_names = None
    if freeze_layer_pattern.strip():
        matches = re.compile(freeze_layer_pattern).match
        fixed_param_names = [
            arg for arg in net.list_arguments() if matches(arg)
        ]

    # initial parameters: resume / finetune / pretrained / from scratch
    ctx_str = '(' + ','.join(str(device) for device in ctx) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, arg_params, aux_params = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, arg_params, aux_params = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        fixed_param_names = [
            arg for arg in net.list_arguments() if arg.startswith('conv')
        ]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        # this branch discards any freeze pattern computed above
        fixed_param_names = None
        _, arg_params, aux_params = mx.model.load_checkpoint(pretrained, epoch)
        arg_params = convert_pretrained(pretrained, arg_params)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        arg_params = None
        aux_params = None
        fixed_param_names = None

    # report which parameters are frozen
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # training module
    mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                        context=ctx, fixed_param_names=fixed_param_names)

    # callbacks
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)

    # optimizer settings; gradients are rescaled by 1 / number of contexts
    learning_rate, lr_scheduler = get_lr_scheduler(
        learning_rate, lr_refactor_step, lr_refactor_ratio, num_example,
        batch_size, begin_epoch)
    num_ctx = len(ctx)
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        'rescale_grad': 1.0 / num_ctx if num_ctx > 0 else 1.0,
    }

    # optional monitor of internal network stats
    monitor = None
    if iter_monitor > 0:
        monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern)

    # validation metric used by fit to report mAP
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    valid_metric = metric_cls(ovp_thresh, use_difficult, class_names,
                              pred_idx=3)

    mod.fit(train_data=train_iter,
            eval_data=val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=arg_params,
            aux_params=aux_params,
            allow_missing=True,
            monitor=monitor)