def get_symbol_train(network, data_shape, **kwargs):
    """Wrapper that builds the training symbol for a base network.

    Parameters
    ----------
    network : str
        name of the base network symbol
    data_shape : int
        input shape
    kwargs : dict
        see symbol_builder.get_symbol_train for more details
    """
    if network.startswith('legacy'):
        logging.warning('Using legacy model.')
        return symbol_builder.import_module(network).get_symbol_train(**kwargs)
    config = get_config(network, data_shape, **kwargs).copy()
    config.update(kwargs)
    return symbol_builder.get_symbol_train(**config)
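
# A minimal usage sketch (hedged): the network name and input size below are
# placeholders, and the exact keyword arguments accepted depend on
# symbol_builder.get_symbol_train:
#
#   sym = get_symbol_train('darknet19_yolo', 416, num_classes=20,
#                          nms_thresh=0.45, force_suppress=False, nms_topk=400)
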
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              optimizer_name='nadam', freeze_layer_pattern='',
              shape_range=(320, 512), random_shape_step=0, random_shape_epoch=10,
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="",
              iter_monitor=0, monitor_pattern=".*", log_file=None):
    """Wrapper for the training phase.

    Parameters
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epochs to rescale learning rate, e.g. '30, 60, 90'
    optimizer_name : str
        name of the optimizer to use, e.g. 'sgd' or 'nadam'
    freeze_layer_pattern : str
        regex pattern matching layers whose weights are to be kept fixed
    shape_range : tuple of (min, max)
        random data shape range
    random_shape_step : int
        step size for random data shape, defined by network, 0 to disable
    random_shape_epoch : int
        number of epochs before the next random shape is drawn
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    ovp_thresh : float
        overlap threshold for a detection to count as correct during validation
    use_difficult : boolean
        include objects marked difficult in validation
    class_names : list of str
        class names used by the validation metric
    voc07_metric : boolean
        use the PASCAL VOC07 11-point metric instead of the integral AP
    nms_topk : int
        apply non-maximum suppression to the top k detections
    force_suppress : boolean
        suppress overlapping detections regardless of class during NMS
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    # drop a trailing '_yolo' from the network name when composing the checkpoint prefix
    base_name = net[:-len('_yolo')] if net.endswith('_yolo') else net
    prefix += '_' + base_name + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # load symbol
    sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
    # symbol_module = importlib.import_module("symbol_" + net)
    # net = symbol_module.get_symbol(num_classes, nms_thresh=nms_thresh,
    #                                force_suppress=force_suppress, nms_topk=nms_topk)
    # detection-specific options are passed as keyword arguments, collected by
    # **kwargs in get_symbol_train above
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh, force_suppress=force_suppress,
                           nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None
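
    # Hedged illustration: with freeze_layer_pattern='^conv1_', any argument whose
    # name matches the regex (e.g. 'conv1_weight', 'conv1_bias'; placeholder names
    # that depend on the base network) ends up in fixed_param_names and is excluded
    # from weight updates.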

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    allow_missing = True
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
                    .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
        allow_missing = False
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
                    .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with relu, so it's fine
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
                    .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
                    .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Frozen parameters: [" + ','.join(fixed_param_names) + ']')

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every few epochs the evaluation network is run to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=4)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=4)

    # init training module
    mod = mx.mod.Module(net, label_names=('yolo_output_label',), logger=logger,
                        context=ctx, fixed_param_names=fixed_param_names)

    # split training into segments; with random shapes enabled, a new input shape
    # is drawn for every segment of random_shape_epoch epochs
    random_shape_step = int(random_shape_step)
    if random_shape_step > 0:
        fit_begins = list(range(begin_epoch, end_epoch, random_shape_epoch))
        fit_ends = fit_begins[1:] + [end_epoch]
        assert len(shape_range) == 2
        data_shapes = [(3, x * random_shape_step, x * random_shape_step)
                       for x in range(shape_range[0] // random_shape_step,
                                      shape_range[1] // random_shape_step + 1)]
        logger.info("Candidate random shapes: " + str(data_shapes))
    else:
        fit_begins = [begin_epoch]
        fit_ends = [end_epoch]
        data_shapes = [data_shape]

    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
                                                   lr_refactor_ratio, num_example,
                                                   batch_size, begin_epoch)
    optimizer_params = {'learning_rate': learning_rate,
                        'wd': weight_decay,
                        'lr_scheduler': lr_scheduler,
                        'clip_gradient': 4.0,
                        'rescale_grad': 1.0}
    if optimizer_name in ('sgd', 'nag'):
        # only SGD-style optimizers take a momentum parameter
        optimizer_params['momentum'] = momentum
    # optimizer_params = {'learning_rate': learning_rate,
    #                     'momentum': momentum,
    #                     'wd': weight_decay,
    #                     'lr_scheduler': lr_scheduler,
    #                     'clip_gradient': 10,
    #                     'rescale_grad': 1.0}

    for begin, end in zip(fit_begins, fit_ends):
        if len(data_shapes) == 1:
            data_shape = data_shapes[0]
        else:
            data_shape = data_shapes[random.randint(0, len(data_shapes) - 1)]
            logger.info("Setting random data shape: " + str(data_shape))

        train_iter = DetRecordIter(train_path, batch_size, data_shape,
                                   mean_pixels=mean_pixels,
                                   label_pad_width=label_pad_width,
                                   path_imglist=train_list, **cfg.train)

        if val_path:
            val_iter = DetRecordIter(val_path, batch_size, data_shape,
                                     mean_pixels=mean_pixels,
                                     label_pad_width=label_pad_width,
                                     path_imglist=val_list, **cfg.valid)
        else:
            val_iter = None

        # bind the module and set parameters on the first pass
        if not mod.binded:
            mod.bind(data_shapes=train_iter.provide_data,
                     label_shapes=train_iter.provide_label)
            mod = set_mod_params(mod, args, auxs, logger)

        mod.fit(train_iter,
                val_iter,
                eval_metric=MultiBoxMetric(),
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin,
                num_epoch=end,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=allow_missing,
                monitor=monitor,
                force_rebind=True,
                force_init=True)
        args, auxs = mod.get_params()
        allow_missing = False
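
# A hedged sketch of how train_net might be driven from a training script; every
# value below (paths, network name, context, hyper-parameters) is a placeholder,
# not a project default:
#
#   train_net('darknet19_yolo', 'data/train.rec', num_classes=20, batch_size=32,
#             data_shape=416, mean_pixels=(123, 117, 104), resume=0, finetune=0,
#             pretrained='model/pretrained', epoch=0, prefix='model/yolo',
#             ctx=[mx.gpu(0)], begin_epoch=0, end_epoch=240, frequent=20,
#             learning_rate=0.001, momentum=0.9, weight_decay=0.0005,
#             lr_refactor_step='80,160', lr_refactor_ratio=0.1,
#             val_path='data/val.rec', voc07_metric=True)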