def get_symbol(network, data_shape, **kwargs):
    """Build the network symbol used for testing/evaluation.

    Parameters
    ----------
    network : str
        Name of the base network symbol. Names starting with ``'legacy'``
        are loaded through ``symbol_builder.import_module`` instead of
        going through ``get_config``.
    data_shape : int
        Input shape, forwarded to ``get_config``. Not used on the legacy
        path.
    kwargs : dict
        Extra options; see ``symbol_builder.get_symbol`` for details. They
        are forwarded to ``get_config`` and then override matching entries
        of the returned config.

    Returns
    -------
    Whatever the selected ``get_symbol`` call returns.
    """
    if network.startswith('legacy'):
        # logging.warn is a deprecated alias (removed in Python 3.13);
        # logging.warning is the supported spelling.
        logging.warning('Using legacy model.')
        return symbol_builder.import_module(network).get_symbol(**kwargs)

    # Work on a copy so the update below does not mutate the object that
    # get_config handed back.
    config = get_config(network, data_shape, **kwargs).copy()
    config.update(kwargs)
    return symbol_builder.get_symbol(**config)
def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape,
                 model_prefix, epoch, ctx=mx.cpu(), batch_size=1,
                 path_imglist="", nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False, class_names=None,
                 voc07_metric=False):
    """Evaluate a network on a validation record file and print the metrics.

    Parameters
    ----------
    net : str or None
        Network name, or None to use the symbol loaded from the checkpoint
        json without modification.
    path_imgrec : str
        Path to the validation record file.
    num_classes : int
        Number of classes, not including background.
    mean_pixels : tuple
        (mean_r, mean_g, mean_b). Accepted for interface compatibility;
        this function does not currently read it.
    data_shape : tuple or int
        (3, height, width), or a single int used for both height and width.
    model_prefix : str
        Prefix of the saved checkpoint.
    epoch : int
        Epoch of the checkpoint to load.
    ctx : mx.ctx
        mx.gpu() or mx.cpu().
    batch_size : int
        Validation batch size.
    path_imglist : str
        Optional path to a list file that replaces the labels stored in the
        record file.
    nms_thresh : float
        Non-maximum suppression threshold.
    force_nms : boolean
        Whether to suppress objects of different classes.
    ovp_thresh : float
        AP overlap threshold for true/false positives.
    use_difficult : boolean
        Whether to use difficult objects in evaluation, if applicable.
    class_names : comma separated str
        Class names; must correspond to num_classes if set.
    voc07_metric : boolean
        Whether to use the 11-point evaluation of the VOC07 competition.

    Returns
    -------
    None. Each (name, value) pair returned by ``Module.score`` is printed.
    """
    # set up logger (configures the root logger at INFO level)
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args: normalise an int shape to (channels, height, width)
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3

    # iterator
    eval_iter = DetRecordIter(path_imgrec, batch_size, data_shape,
                              path_imglist=path_imglist, **cfg.valid)

    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)

    # network
    if net is None:
        net = load_net
    else:
        symbol_dir = os.path.join(cfg.ROOT_DIR, 'symbol')
        # Avoid piling up duplicate entries when called repeatedly.
        if symbol_dir not in sys.path:
            sys.path.append(symbol_dir)
        # get_symbol(network, data_shape, **kwargs) accepts only two
        # positional arguments, so the options must be passed by keyword;
        # the previous four-positional call raised TypeError.
        # NOTE(review): data_shape[1] (height) is passed as the int input
        # shape, and the keyword names mirror this function's parameters --
        # confirm both against get_config / symbol_builder.get_symbol.
        net = get_symbol('symbol_' + net, data_shape[1],
                         num_classes=num_classes, nms_thresh=nms_thresh,
                         force_nms=force_nms)

    # Make sure the label variable exists in the symbol so the module can
    # bind the iterator's labels to it.
    if 'yolo_output_label' not in net.list_arguments():
        label = mx.sym.Variable(name='yolo_output_label')
        net = mx.sym.Group([net, label])

    # init module; every argument is listed as fixed since this is
    # evaluation only
    mod = mx.mod.Module(net, label_names=('yolo_output_label', ),
                        logger=logger, context=ctx,
                        fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # run evaluation
    if voc07_metric:
        metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names)
    else:
        metric = MApMetric(ovp_thresh, use_difficult, class_names)
    results = mod.score(eval_iter, metric, num_batch=None)
    for k, v in results:
        print("{}: {}".format(k, v))