def net_visualization(network=None, num_classes=None, data_shape=None, train=None, output_dir=None, print_net=False, net=None):
    """Render a graphviz plot of an SSD symbol into ``output_dir``.

    If ``net`` is given the caller already built the symbol and it is used
    as-is; otherwise the symbol is created from ``network``/``data_shape``
    via ``symbol_factory``.  Optionally dumps the symbol JSON to stdout.
    """
    # Build the symbol only when the caller did not hand one in.
    if net is None:
        builder = symbol_factory.get_symbol_train if train else symbol_factory.get_symbol
        net = builder(network, data_shape, num_classes=num_classes)
    node_style = {"shape": 'rect', "fixedsize": 'false'}
    if train:
        # Train symbol: input shapes are not pinned down, so plot without them.
        graph = mx.viz.plot_network(net, shape=None, node_attrs=node_style)
        filename = "ssd_" + network + '_' + 'train'
    else:
        # Test symbol: shapes are known, annotate the plot with them.
        graph = mx.viz.plot_network(net, shape={"data": (1, 3, data_shape, data_shape)},
                                    node_attrs=node_style)
        filename = "ssd_" + network + '_' + str(data_shape) + '_' + 'test'
    graph.render(os.path.join(output_dir, filename))
    if print_net:
        print(net.tojson())
def train_net(net, train_path, num_classes, batch_size, data_shape, mean_pixels, resume,
              finetune, pretrained, epoch, prefix, ctx, begin_epoch, end_epoch, solver,
              frequent, learning_rate, momentum, weight_decay, lr_refactor_step,
              lr_refactor_ratio, freeze_layer_pattern='', num_example=10000,
              label_pad_width=350, nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None, voc07_metric=False, nms_topk=400,
              force_suppress=False, train_list="", val_path="", val_list="",
              iter_monitor=0, monitor_pattern=".*", log_file=None, lite=False,
              kv_store=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    solver : str
        optimizer name, 'sgd' or 'rmsprop' (see NOTE below)
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        trainig momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlaped objects from different classes
    train_list : str
        list file path for training, this will replace the embeded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embeded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    lite : bool
        use the 'lite' network variant; also tags the checkpoint prefix
    kv_store : str
        kvstore type for multi-device / distributed training, or None
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    # checkpoint prefix encodes the network name and input resolution
    if lite:
        prefix += 'lite'
    prefix += '_' + net + '_' + str(data_shape[1]) + '_' + str(data_shape[2])
    print(prefix)
    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path, batch_size, data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)
    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol (note: `net` is rebound from the name string to the symbol here)
    net = get_symbol_train(net, data_shape, num_classes=num_classes,
                           nms_thresh=nms_thresh, force_suppress=force_suppress,
                           nms_topk=nms_topk, lite=lite)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        # (freezing every 'conv*' argument leaves the prediction heads trainable)
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    # NOTE(review): end_epoch / 2 is a float under Python 3 — do_checkpoint
    # coerces the period internally, but `end_epoch // 2` may be the intent; confirm.
    epoch_end_callback = mx.callback.do_checkpoint(prefix, end_epoch / 2)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    # NOTE(review): if `solver` is neither 'sgd' nor 'rmsprop', optimizer_params
    # is never assigned and mod.fit below raises NameError — confirm callers
    # restrict solver to these two values.
    if solver == 'sgd':
        optimizer_params = {
            'learning_rate': learning_rate,
            'momentum': momentum,
            'wd': weight_decay,
            'lr_scheduler': lr_scheduler,
            'clip_gradient': None,
            # average gradients across devices
            'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
        }
    elif solver == 'rmsprop':
        optimizer_params = {
            'learning_rate': learning_rate,
            'gamma1': 0.5,
            'wd': weight_decay,
            'lr_scheduler': lr_scheduler,
            'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
        }
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                      pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names,
                                 pred_idx=3)

    # create kvstore when there are gpus
    kv = mx.kvstore.create(kv_store) if kv_store else None

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer=solver,
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor,
            kvstore=kv)
def train_net(network, train_path, num_classes, batch_size, data_shape, mean_pixels,
              resume, finetune, pretrained, epoch, prefix, ctx, begin_epoch, end_epoch,
              frequent, learning_rate, momentum, weight_decay, lr_refactor_step,
              lr_refactor_ratio, alpha_bb8=1.0, freeze_layer_pattern='',
              num_example=5717, label_pad_width=350, nms_thresh=0.45, force_nms=False,
              ovp_thresh=0.5, use_difficult=False, class_names=None, voc07_metric=False,
              nms_topk=400, force_suppress=False, train_list="", val_path="",
              val_list="", iter_monitor=0, monitor_pattern=".*", log_file=None,
              optimizer='sgd', tensorboard=False, checkpoint_period=5,
              min_neg_samples=0):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    optimizer : str
        usage of different optimizers, other then default sgd
    learning_rate : float
        training learning rate
    momentum : float
        trainig momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    alpha_bb8 : float
        forwarded verbatim to get_symbol_train (loss weighting — semantics
        defined in the symbol code; confirm there)
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlaped objects from different classes
    train_list : str
        list file path for training, this will replace the embeded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embeded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    tensorboard : bool
        record logs into tensorboard
    min_neg_samples : int
        always have some negative examples, no matter how many positive there are.
        this is useful when training on images with no ground-truth.
    checkpoint_period : int
        a checkpoint will be saved every "checkpoint_period" epochs
    """
    # check actual number of train_images: derive num_example from the .idx
    # companion of the .rec file when it exists (one line per image)
    if os.path.exists(train_path.replace('rec', 'idx')):
        with open(train_path.replace('rec', 'idx'), 'r') as f:
            txt = f.readlines()
        num_example = len(txt)

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        log_file_path = os.path.join(os.path.dirname(prefix), log_file)
        if not os.path.exists(os.path.dirname(log_file_path)):
            os.makedirs(os.path.dirname(log_file_path))
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])
    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path,
                               batch_size, data_shape, mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)
    # NOTE(review): `label` (and `val_label` below) are never used afterwards —
    # they look like debugging leftovers; confirm before removing.
    label = train_iter._batch.label[0].asnumpy()
    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)
        val_label = val_iter._batch.label[0].asnumpy()
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(network, data_shape[1], alpha_bb8,
                           num_classes=num_classes, nms_thresh=nms_thresh,
                           force_suppress=force_suppress, nms_topk=nms_topk,
                           minimum_negative_samples=min_neg_samples)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # check what layers mismatch with the loaded parameters
        # NOTE(review): bind shape (1, 3, 300, 300) is hard-coded — confirm it
        # should instead follow data_shape for non-300 resolutions.
        exe = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300), label=(1, 1, 5),
                              grad_req='null')
        arg_dict = exe.arg_dict
        fixed_param_names = []
        # drop checkpoint params whose shapes no longer match; freeze everything
        # else except the prediction ('pred') layers.  Deleting from `args`
        # while iterating `arg_dict` is safe — they are different dicts.
        for k, v in arg_dict.items():
            if k in args:
                if v.shape != args[k].shape:
                    del args[k]
                    # NOTE(review): uses the root `logging` module, unlike the
                    # `logger` used elsewhere — probably unintentional.
                    logging.info("Removed %s" % k)
                else:
                    if not 'pred' in k:
                        fixed_param_names.append(k)
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # visualize net - both train and test
    net_visualization(net=net, network=network, data_shape=data_shape[2],
                      output_dir=os.path.dirname(prefix), train=True)
    # net_visualization(net=None, network=network, data_shape=data_shape[2],
    #                   output_dir=os.path.dirname(prefix), train=False, num_classes=num_classes)

    # init training module: data/label names are taken from the iterator
    data_names = [k[0] for k in train_iter.provide_data]
    label_names = [k[0] for k in train_iter.provide_label]
    mod = mx.mod.Module(net, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    batch_end_callback = []
    eval_end_callback = []
    epoch_end_callback = [
        mx.callback.do_checkpoint(prefix, period=checkpoint_period)
    ]

    # add logging to tensorboard
    if tensorboard:
        tensorboard_dir = os.path.join(os.path.dirname(prefix), 'logs')
        if not os.path.exists(tensorboard_dir):
            os.makedirs(os.path.join(tensorboard_dir, 'train', 'scalar'))
            os.makedirs(os.path.join(tensorboard_dir, 'train', 'dist'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'roc'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'scalar'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'images'))
        # NOTE(review): log_file_path is only bound when log_file is set —
        # tensorboard=True without log_file raises NameError here; confirm.
        batch_end_callback.append(
            ParseLogCallback(
                dist_logging_dir=os.path.join(tensorboard_dir, 'train', 'dist'),
                scalar_logging_dir=os.path.join(tensorboard_dir, 'train', 'scalar'),
                logfile_path=log_file_path,
                batch_size=batch_size,
                iter_monitor=iter_monitor,
                frequent=frequent))
        eval_end_callback.append(
            LogMetricsCallback(os.path.join(tensorboard_dir, 'val/scalar'),
                               'ssd', global_step=0))
        # eval_end_callback.append(LogROCCallback(logging_dir=os.path.join(tensorboard_dir, 'val/roc'),
        #                                         roc_path=os.path.join(os.path.dirname(prefix), 'roc'),
        #                                         class_names=class_names))
        # eval_end_callback.append(LogDetectionsCallback(logging_dir=os.path.join(tensorboard_dir, 'val/images'),
        #                                                images_path=os.path.join(os.path.dirname(prefix), 'images'),
        #                                                class_names=class_names,batch_size=batch_size,mean_pixels=mean_pixels))

    # this callback should be the last in a serie of batch_callbacks
    # since it is resetting the metric evaluation every $frequent batches
    batch_end_callback.append(
        mx.callback.Speedometer(train_iter.batch_size, frequent=frequent))

    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    logger.info(
        "learning rate: {}, lr refactor step: {}, lr refactor ratio: {}, batch size: {}."
        .format(learning_rate, lr_refactor_step, lr_refactor_ratio, batch_size))

    # add possibility for different optimizer
    opt, opt_params = get_optimizer_params(optimizer=optimizer,
                                           learning_rate=learning_rate,
                                           momentum=momentum,
                                           weight_decay=weight_decay,
                                           lr_scheduler=lr_scheduler,
                                           ctx=ctx, logger=logger)
    # log the resolved optimizer settings (scheduler object itself is skipped)
    logger.info("Optimizer: {}".format(opt))
    for k, v in opt_params.items():
        if k == 'lr_scheduler':
            continue
        logger.info("{}: {}".format(k, v))

    # TODO monitor the gradient flow as in 'https://github.com/dmlc/tensorboard/blob/master/docs/tutorial/understanding-vanish-gradient.ipynb'
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    # NOTE(review): valid_metric is built but mod.fit below passes a fresh
    # MultiBoxMetric as validation_metric (per the inline comment), so mAP is
    # not actually computed during validation — confirm this is intentional.
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                      pred_idx=4,
                                      roc_output_path=os.path.join(
                                          os.path.dirname(prefix), 'roc'))
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names,
                                 pred_idx=4,
                                 roc_output_path=os.path.join(
                                     os.path.dirname(prefix), 'roc'))

    mod.fit(
        train_iter,
        val_iter,
        eval_metric=MultiBoxMetric(),
        validation_metric=MultiBoxMetric(
        ),  # use 'valid_metric' for calculate mAP
        batch_end_callback=batch_end_callback,
        eval_end_callback=eval_end_callback,
        epoch_end_callback=epoch_end_callback,
        optimizer=opt,
        optimizer_params=opt_params,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        initializer=mx.init.Xavier(),
        arg_params=args,
        aux_params=auxs,
        allow_missing=True,
        monitor=monitor)
def train_net_common(net, train_iter, val_iter, batch_size, data_shape, resume,
                     finetune, pretrained, epoch, prefix, ctx, begin_epoch, end_epoch,
                     frequent, learning_rate, momentum, weight_decay, use_plateau,
                     lr_refactor_step, lr_refactor_ratio, freeze_layer_pattern='',
                     num_example=10000, label_pad_width=350, nms_thresh=0.45,
                     force_suppress=False, ovp_thresh=0.5, use_difficult=False,
                     class_names=None, optimizer_name='sgd', voc07_metric=False,
                     nms_topk=400, iter_monitor=0, monitor_pattern=".*", logger=None):
    """Shared training driver: builds the symbol, loads/freezes parameters and
    runs ``mod.fit`` with either a plain Module or a PlateauModule (plateau
    learning-rate scheduling) depending on ``use_plateau``.

    ``net`` is the symbol *name* (str); iterators are supplied by the caller.
    ``logger`` must be a configured logging.Logger — it is used unconditionally.
    """
    # check args: checkpoint prefix encodes network name and input resolution
    prefix += '_' + net + '_' + str(data_shape[1])

    # load symbol (keep the name string around for the pvanet special case)
    net_str = net
    net = get_symbol_train(net, data_shape[1], num_classes=len(class_names),
                           nms_thresh=nms_thresh, force_suppress=force_suppress,
                           nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '('+ ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        # (freezing every 'conv*' argument leaves the prediction heads trainable)
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        # NOTE(review): bare `except:` also swallows KeyboardInterrupt/SystemExit
        # and hides the real load error — consider `except Exception` + logging it.
        try:
            logger.info("Start training with {} from pretrained model {}"
                .format(ctx_str, pretrained))
            _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
            args = convert_pretrained(pretrained, args)
            if net_str == 'ssd_pva':
                args, auxs = convert_pvanet(args, auxs)
        except:
            logger.info("Failed to load the pretrained model. Start from scratch.")
            args = None
            auxs = None
            fixed_param_names = None
    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    if not use_plateau:  # focal loss does not go well with plateau
        mod = mx.mod.Module(net, label_names=('label',), logger=logger,
                            context=ctx, fixed_param_names=fixed_param_names)
    else:
        mod = PlateauModule(net, label_names=('label',), logger=logger,
                            context=ctx, fixed_param_names=fixed_param_names)

    # robust parameter setting: bind first, then patch loaded params in
    mod.bind(data_shapes=train_iter.provide_data,
             label_shapes=train_iter.provide_label)
    mod = set_mod_params(mod, args, auxs, logger)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent,
                                                 auto_reset=True)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) \
        if iter_monitor > 0 else None
    optimizer_params={'learning_rate': learning_rate,
                      'wd': weight_decay,
                      'clip_gradient': 4.0,
                      # average gradients across devices
                      'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
                      }
    if optimizer_name == 'sgd':
        optimizer_params['momentum'] = momentum

    # #7847 (mxnet issue: force re-init of the optimizer after binding)
    mod.init_optimizer(optimizer=optimizer_name,
                       optimizer_params=optimizer_params, force_init=True)

    if not use_plateau:
        # NOTE(review): lr_scheduler is computed here but never attached to
        # optimizer_params, and init_optimizer already ran above — the schedule
        # appears to have no effect in this branch; confirm intent.
        learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                       lr_refactor_step,
                                                       lr_refactor_ratio,
                                                       num_example, batch_size,
                                                       begin_epoch)
    else:
        w_l1 = cfg.train['smoothl1_weight']
        eval_weights = {'CrossEntropy': 1.0, 'SmoothL1': w_l1}
        plateau_lr = PlateauScheduler( \
                patient_epochs=lr_refactor_step,
                factor=float(lr_refactor_ratio), eval_weights=eval_weights)
        # NOTE(review): hard-coded absolute path to a developer machine —
        # parameterize or remove before reuse.
        plateau_metric = MultiBoxMetric(
            fn_stat='/home/hyunjoon/github/additions_mxnet/ssd/stat.txt')

    eval_metric = MultiBoxMetric()

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                      pred_idx=4)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names,
                                 pred_idx=4)

    if not use_plateau:
        mod.fit(train_iter,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
    else:
        mod.fit(train_iter,
                plateau_lr,
                plateau_metric=plateau_metric,
                fn_curr_model=prefix+'-1000.params',
                plateau_backtrace=False,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                validation_period=5,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
def train_net(network, train_path, num_classes, batch_size, data_shape, mean_pixels,
              resume, finetune, pretrained, epoch, prefix, ctx, begin_epoch, end_epoch,
              frequent, learning_rate, momentum, weight_decay, lr_refactor_step,
              lr_refactor_ratio, freeze_layer_pattern='', num_example=10000,
              label_pad_width=350, nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None, voc07_metric=False, nms_topk=400,
              force_suppress=False, train_list="", val_path="", val_list="",
              iter_monitor=0, monitor_pattern=".*", log_file=None, optimizer='sgd',
              tensorboard=False, checkpoint_period=5, min_neg_samples=0):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    optimizer : str
        usage of different optimizers, other then default sgd
    learning_rate : float
        training learning rate
    momentum : float
        trainig momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlaped objects from different classes
    train_list : str
        list file path for training, this will replace the embeded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embeded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    tensorboard : bool
        record logs into tensorboard
    min_neg_samples : int
        always have some negative examples, no matter how many positive there are.
        this is useful when training on images with no ground-truth.
    checkpoint_period : int
        a checkpoint will be saved every "checkpoint_period" epochs
    """
    # check actual number of train_images: derive num_example from the .idx
    # companion of the .rec file when it exists (one line per image)
    if os.path.exists(train_path.replace('rec','idx')):
        with open(train_path.replace('rec','idx'), 'r') as f:
            txt = f.readlines()
        num_example = len(txt)

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        log_file_path = os.path.join(os.path.dirname(prefix), log_file)
        if not os.path.exists(os.path.dirname(log_file_path)):
            os.makedirs(os.path.dirname(log_file_path))
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])
    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path,
                               batch_size, data_shape, mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)
    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(network, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh, force_suppress=force_suppress,
                           nms_topk=nms_topk,
                           minimum_negative_samples=min_neg_samples)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # check what layers mismatch with the loaded parameters
        # NOTE(review): bind shape (1, 3, 300, 300) is hard-coded — confirm it
        # should instead follow data_shape for non-300 resolutions.
        exe = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300), label=(1, 1, 5),
                              grad_req='null')
        arg_dict = exe.arg_dict
        fixed_param_names = []
        # drop checkpoint params whose shapes no longer match; freeze everything
        # else except the prediction ('pred') layers.  Deleting from `args`
        # while iterating `arg_dict` is safe — they are different dicts.
        for k, v in arg_dict.items():
            if k in args:
                if v.shape != args[k].shape:
                    del args[k]
                    # NOTE(review): uses the root `logging` module, unlike the
                    # `logger` used elsewhere — probably unintentional.
                    logging.info("Removed %s" % k)
                else:
                    if not 'pred' in k:
                        fixed_param_names.append(k)
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
            .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" +
                    ','.join(fixed_param_names) + ']')

    # visualize net - both train and test
    net_visualization(net=net, network=network,data_shape=data_shape[2],
                      output_dir=os.path.dirname(prefix), train=True)
    net_visualization(net=None, network=network, data_shape=data_shape[2],
                      output_dir=os.path.dirname(prefix), train=False,
                      num_classes=num_classes)

    # init training module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    batch_end_callback = []
    eval_end_callback = []
    epoch_end_callback = [mx.callback.do_checkpoint(prefix,
                                                    period=checkpoint_period)]

    # add logging to tensorboard
    if tensorboard:
        tensorboard_dir = os.path.join(os.path.dirname(prefix), 'logs')
        if not os.path.exists(tensorboard_dir):
            os.makedirs(os.path.join(tensorboard_dir, 'train', 'scalar'))
            os.makedirs(os.path.join(tensorboard_dir, 'train', 'dist'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'roc'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'scalar'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'images'))
        # NOTE(review): log_file_path is only bound when log_file is set —
        # tensorboard=True without log_file raises NameError here; confirm.
        batch_end_callback.append(
            ParseLogCallback(dist_logging_dir=os.path.join(tensorboard_dir, 'train', 'dist'),
                             scalar_logging_dir=os.path.join(tensorboard_dir, 'train', 'scalar'),
                             logfile_path=log_file_path,
                             batch_size=batch_size,
                             iter_monitor=iter_monitor,
                             frequent=frequent))
        eval_end_callback.append(mx.contrib.tensorboard.LogMetricsCallback(
            os.path.join(tensorboard_dir, 'val/scalar'), 'ssd'))
        eval_end_callback.append(LogROCCallback(logging_dir=os.path.join(tensorboard_dir, 'val/roc'),
                                                roc_path=os.path.join(os.path.dirname(prefix), 'roc'),
                                                class_names=class_names))
        eval_end_callback.append(LogDetectionsCallback(logging_dir=os.path.join(tensorboard_dir, 'val/images'),
                                                       images_path=os.path.join(os.path.dirname(prefix), 'images'),
                                                       class_names=class_names,batch_size=batch_size,mean_pixels=mean_pixels))

    # this callback should be the last in a serie of batch_callbacks
    # since it is resetting the metric evaluation every $frequent batches
    batch_end_callback.append(mx.callback.Speedometer(train_iter.batch_size,
                                                      frequent=frequent))

    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    # add possibility for different optimizer
    opt, opt_params = get_optimizer_params(optimizer=optimizer,
                                           learning_rate=learning_rate,
                                           momentum=momentum,
                                           weight_decay=weight_decay,
                                           lr_scheduler=lr_scheduler,
                                           ctx=ctx, logger=logger)
    # TODO monitor the gradient flow as in 'https://github.com/dmlc/tensorboard/blob/master/docs/tutorial/understanding-vanish-gradient.ipynb'
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) \
        if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                      pred_idx=3,
                                      roc_output_path=os.path.join(os.path.dirname(prefix), 'roc'))
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names,
                                 pred_idx=3,
                                 roc_output_path=os.path.join(os.path.dirname(prefix), 'roc'))

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            eval_end_callback=eval_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer=opt,
            optimizer_params=opt_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None, kv_store=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold, forwarded to the symbol builder
    force_nms : boolean
        suppress overlaped objects from different classes (currently unused here)
    ovp_thresh : float
        overlap threshold used by the validation mAP metric
    use_difficult : boolean
        include objects marked difficult in the validation metric
    class_names : list of str or None
        class names passed to the validation metric
    voc07_metric : boolean
        use the PASCAL VOC07 11-point mAP computation for validation
    nms_topk : int
        keep at most this many detections after NMS, forwarded to the symbol builder
    force_suppress : boolean
        forwarded to the symbol builder together with the NMS settings
    train_list : str
        list file path for training, this will replace the embeded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embeded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    kv_store : str or None
        kvstore type string passed to mx.kvstore.create; None means no kvstore
    """
    # set up logger: log to console, and additionally to a file when requested
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args: normalize a bare int into a square (3, H, W) shape
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    # checkpoint prefix encodes network name and input size
    prefix += '_' + net + '_' + str(data_shape[1])

    # broadcast a scalar mean to all three RGB channels
    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # DetRecordIter / cfg are project helpers (defined elsewhere in this repo)
    train_iter = DetRecordIter(train_path, batch_size, data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol: `net` changes type here from str to an mxnet Symbol
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
                    .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
                    .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        # to freeze everything whose name starts with 'conv'
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
                    .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
                    .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger,
                        context=ctx, fixed_param_names=fixed_param_names)

    # fit parameters: throughput printout + per-epoch checkpointing
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    # get_lr_scheduler (project helper) may also rescale the starting LR
    # when begin_epoch is already past some refactor steps
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    optimizer_params = {'learning_rate': learning_rate,
                        'momentum': momentum,
                        'wd': weight_decay,
                        'lr_scheduler': lr_scheduler,
                        'clip_gradient': None,
                        # average gradients over devices
                        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) \
        if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                      pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names,
                                 pred_idx=3)

    # create kvstore when there are gpus
    kv = mx.kvstore.create(kv_store) if kv_store else None

    mod.fit(train_iter, val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor,
            kvstore=kv)
parser.add_argument('--network', type=str, default='vgg16_reduced', help='the cnn to use') parser.add_argument('--num-classes', type=int, default=20, help='the number of classes') parser.add_argument('--data-shape', type=int, default=300, help='set image\'s shape') parser.add_argument('--train', action='store_true', default=False, help='show train net') args = parser.parse_args() if not args.train: net = symbol_factory.get_symbol(args.network, args.data_shape, num_classes=args.num_classes) a = mx.viz.plot_network(net, shape={"data":(1,3,args.data_shape,args.data_shape)}, \ node_attrs={"shape":'rect', "fixedsize":'false'}) a.render("ssd_" + args.network + '_' + str(args.data_shape)) else: net = symbol_factory.get_symbol_train(args.network, args.data_shape, num_classes=args.num_classes) print(net.tojson())
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, use_plateau, lr_refactor_step,
              lr_refactor_ratio, use_global_stats=0, freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None, ignore_names=None,
              optimizer_name='sgd', voc07_metric=False, nms_topk=400,
              force_suppress=False, train_list="", val_path="", val_list="",
              iter_monitor=0, monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum (applied only when optimizer_name == 'sgd')
    weight_decay : float
        training weight decay param
    use_plateau : boolean
        use PlateauModule/PlateauScheduler (reduce LR when the eval metric
        plateaus) instead of a step learning-rate schedule
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90';
        with use_plateau these act as patience epochs instead
    use_global_stats : int
        forwarded to the symbol builder (batch-norm global statistics flag)
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlaped objects from different classes (currently unused here)
    ovp_thresh : float
        overlap threshold used by the validation metrics
    use_difficult : boolean
        include objects marked difficult in the validation metrics
    class_names : list of str or None
        class names passed to the validation metrics
    ignore_names : list or None
        forwarded to the symbol builder
    optimizer_name : str
        mxnet optimizer name, e.g. 'sgd'
    voc07_metric : boolean
        use the PASCAL VOC07 11-point mAP computation for validation
    nms_topk : int
        keep at most this many detections after NMS
    force_suppress : boolean
        forwarded to the symbol builder together with the NMS settings
    train_list : str
        list file path for training, this will replace the embeded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embeded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger: console plus optional file handler
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args: normalize a bare int into a square (3, H, W) shape
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # DetRecordIter / cfg are project helpers (defined elsewhere in this repo)
    train_iter = DetRecordIter(train_path, batch_size, data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol; keep the original name string for the pvanet special case
    net_str = net
    net = get_symbol_train(net, data_shape[1],
                           use_global_stats=use_global_stats,
                           num_classes=num_classes, ignore_names=ignore_names,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments()
                             if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        try:
            logger.info(
                "Start training with {} from pretrained model {}".format(
                    ctx_str, pretrained))
            _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
            args = convert_pretrained(pretrained, args)
            if net_str == 'ssd_pva':
                args, auxs = convert_pvanet(args, auxs)
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt /
        # SystemExit and hid the traceback; catch Exception and log it.
        except Exception:
            logger.exception(
                "Failed to load the pretrained model. Start from scratch.")
            args = None
            auxs = None
            fixed_param_names = None
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    if not use_plateau:  # focal loss does not go well with plateau
        mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                            context=ctx, fixed_param_names=fixed_param_names)
    else:
        mod = PlateauModule(net, label_names=('label', ), logger=logger,
                            context=ctx, fixed_param_names=fixed_param_names)

    # robust parameter setting
    mod.bind(data_shapes=train_iter.provide_data,
             label_shapes=train_iter.provide_label)
    mod = set_mod_params(mod, args, auxs, logger)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent,
                                                 auto_reset=True)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) \
        if iter_monitor > 0 else None

    optimizer_params = {
        'learning_rate': learning_rate,
        'wd': weight_decay,
        'clip_gradient': 4.0,
        # average gradients over devices
        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    }
    if optimizer_name == 'sgd':
        optimizer_params['momentum'] = momentum

    # Build the learning-rate schedule BEFORE initializing the optimizer.
    # FIX: previously get_lr_scheduler() was called after the forced
    # init_optimizer and its result discarded, so the step LR schedule
    # never took effect; install it into optimizer_params as the other
    # train_net variants in this file do.
    if not use_plateau:
        learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                       lr_refactor_step,
                                                       lr_refactor_ratio,
                                                       num_example, batch_size,
                                                       begin_epoch)
        optimizer_params['learning_rate'] = learning_rate
        optimizer_params['lr_scheduler'] = lr_scheduler
    else:
        w_l1 = cfg.train['smoothl1_weight']
        eval_weights = {
            'CrossEntropy': 1.0,
            'SmoothL1': w_l1,
            'ObjectRecall': 0.0
        }
        plateau_lr = PlateauScheduler(
            patient_epochs=lr_refactor_step,
            factor=float(lr_refactor_ratio), eval_weights=eval_weights)
        # NOTE(review): hardcoded absolute path — only valid on the original
        # author's machine; should be turned into a parameter.
        plateau_metric = MultiBoxMetric(
            fn_stat='/home/hyunjoon/github/additions_mxnet/ssd/stat.txt')

    # forced re-init works around apache/incubator-mxnet issue #7847
    mod.init_optimizer(optimizer=optimizer_name,
                       optimizer_params=optimizer_params, force_init=True)

    eval_metric = MultiBoxMetric()
    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        map_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                    pred_idx=4)
        recall_metric = RecallMetric(ovp_thresh, use_difficult, pred_idx=4)
        valid_metric = mx.metric.create([map_metric, recall_metric])
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names,
                                 pred_idx=4)

    if not use_plateau:
        mod.fit(train_iter,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
    else:
        mod.fit(train_iter,
                plateau_lr,
                plateau_metric=plateau_metric,
                fn_curr_model=prefix + '-1000.params',
                plateau_backtrace=False,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                validation_period=5,
                kvstore='local',
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=2000, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold, forwarded to the symbol builder
    force_nms : boolean
        suppress overlaped objects from different classes (currently unused here)
    ovp_thresh : float
        overlap threshold used by the validation mAP metric
    use_difficult : boolean
        include objects marked difficult in the validation metric
    class_names : list of str or None
        class names passed to the validation metric
    voc07_metric : boolean
        use the PASCAL VOC07 11-point mAP computation for validation
    nms_topk : int
        keep at most this many detections after NMS (note: this variant
        defaults to 2000 rather than 400)
    force_suppress : boolean
        forwarded to the symbol builder together with the NMS settings
    train_list : str
        list file path for training, this will replace the embeded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embeded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger: log to console, and additionally to a file when requested
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        # create a file handler for the requested log path
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args: normalize a bare int into a square (3, H, W) shape
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    # checkpoint prefix encodes network name and input size
    prefix += '_' + net + '_' + str(data_shape[1])

    # broadcast a scalar mean to all three RGB channels
    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # DetRecordIter / cfg are project helpers (defined elsewhere in this repo)
    train_iter = DetRecordIter(train_path, batch_size, data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list, **cfg.train)
    # (a large commented-out matplotlib/cv2 batch-visualization debug loop
    # was removed here; restore from version control if needed)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol: `net` changes type here from str to an mxnet Symbol
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers name starts with relu, so it's fine
        # to freeze everything whose name starts with 'conv'
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        # unlike the finetune branch, no layers are frozen here
        fixed_param_names = None
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    mod = mx.mod.Module(net, label_names=('label', ), logger=logger,
                        context=ctx, fixed_param_names=fixed_param_names)

    # fit parameters: throughput printout + per-epoch checkpointing
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    # get_lr_scheduler (project helper) may also rescale the starting LR
    # when begin_epoch is already past some refactor steps
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        # average gradients over devices
        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    }
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names,
                                      pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names,
                                 pred_idx=3)

    mod.fit(
        train_data=train_iter,
        eval_data=val_iter,
        eval_metric=MultiBoxMetric(),
        validation_metric=valid_metric,
        batch_end_callback=batch_end_callback,
        epoch_end_callback=epoch_end_callback,
        optimizer='sgd',
        optimizer_params=optimizer_params,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        initializer=mx.init.Xavier(),
        arg_params=args,
        aux_params=auxs,
        allow_missing=True,
        monitor=monitor)