def test(net, val_data, batch_fn, data_source_needs_reset, dtype, ctx, input_image_size, in_channels, calc_weight_count=False, calc_flops=False, calc_flops_only=True, extended_log=False): if not calc_flops_only: acc_top1 = mx.metric.Accuracy() acc_top5 = mx.metric.TopKAccuracy(5) tic = time.time() err_top1_val, err_top5_val = validate( acc_top1=acc_top1, acc_top5=acc_top5, net=net, val_data=val_data, batch_fn=batch_fn, data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx) if extended_log: logging.info( 'Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})' .format(top1=err_top1_val, top5=err_top5_val)) else: logging.info( 'Test: err-top1={top1:.4f}\terr-top5={top5:.4f}'.format( top1=err_top1_val, top5=err_top5_val)) logging.info('Time cost: {:.4f} sec'.format(time.time() - tic)) if calc_weight_count: weight_count = calc_net_weight_count(net) if not calc_flops: logging.info("Model: {} trainable parameters".format(weight_count)) if calc_flops: num_flops, num_macs, num_params = measure_model( net, in_channels, input_image_size, ctx[0]) assert (not calc_weight_count) or (weight_count == num_params) stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \ " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)" logging.info( stat_msg.format(params=num_params, params_m=num_params / 1e6, flops=num_flops, flops_m=num_flops / 1e6, flops2=num_flops / 2, flops2_m=num_flops / 2 / 1e6, macs=num_macs, macs_m=num_macs / 1e6))
def test(net, test_data, batch_fn, data_source_needs_reset, metric, dtype, ctx, input_image_size, in_channels, calc_weight_count=False, calc_flops=False, calc_flops_only=True, extended_log=False): if not calc_flops_only: tic = time.time() validate(metric=metric, net=net, val_data=test_data, batch_fn=batch_fn, data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx) accuracy_msg = report_accuracy(metric=metric, extended_log=extended_log) logging.info("Test: {}".format(accuracy_msg)) logging.info("Time cost: {:.4f} sec".format(time.time() - tic)) if calc_weight_count: weight_count = calc_net_weight_count(net) if not calc_flops: logging.info("Model: {} trainable parameters".format(weight_count)) if calc_flops: num_flops, num_macs, num_params = measure_model( net, in_channels, input_image_size, ctx[0]) assert (not calc_weight_count) or (weight_count == num_params) stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \ " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)" logging.info( stat_msg.format(params=num_params, params_m=num_params / 1e6, flops=num_flops, flops_m=num_flops / 1e6, flops2=num_flops / 2, flops2_m=num_flops / 2 / 1e6, macs=num_macs, macs_m=num_macs / 1e6))
def test(net, test_data, data_source_needs_reset, dtype, ctx, input_image_size, in_channels, classes, calc_weight_count=False, calc_flops=False, calc_flops_only=True, extended_log=False, dataset_metainfo=None): assert (dataset_metainfo is not None) if not calc_flops_only: metric = mx.metric.CompositeEvalMetric() pix_acc_macro_average = False metric.add(PixelAccuracyMetric( vague_idx=dataset_metainfo["vague_idx"], use_vague=dataset_metainfo["use_vague"], macro_average=pix_acc_macro_average)) mean_iou_macro_average = False metric.add(MeanIoUMetric( num_classes=classes, vague_idx=dataset_metainfo["vague_idx"], use_vague=dataset_metainfo["use_vague"], bg_idx=dataset_metainfo["background_idx"], ignore_bg=dataset_metainfo["ignore_bg"], macro_average=mean_iou_macro_average)) tic = time.time() accuracy_info = validate1( accuracy_metric=metric, net=net, val_data=test_data, batch_fn=batch_fn, data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx) pix_acc = accuracy_info[1][0] mean_iou = accuracy_info[1][1] pix_macro = "macro" if pix_acc_macro_average else "micro" iou_macro = "macro" if mean_iou_macro_average else "micro" if extended_log: logging.info( "Test: {pix_macro}-pix_acc={pix_acc:.4f} ({pix_acc}), " "{iou_macro}-mean_iou={mean_iou:.4f} ({mean_iou})".format( pix_macro=pix_macro, pix_acc=pix_acc, iou_macro=iou_macro, mean_iou=mean_iou)) else: logging.info("Test: {pix_macro}-pix_acc={pix_acc:.4f}, {iou_macro}-mean_iou={mean_iou:.4f}".format( pix_macro=pix_macro, pix_acc=pix_acc, iou_macro=iou_macro, mean_iou=mean_iou)) logging.info("Time cost: {:.4f} sec".format( time.time() - tic)) if calc_weight_count: weight_count = calc_net_weight_count(net) if not calc_flops: logging.info("Model: {} trainable parameters".format(weight_count)) if calc_flops: num_flops, num_macs, num_params = measure_model(net, in_channels, input_image_size, ctx[0]) assert (not calc_weight_count) or (weight_count == num_params) stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: 
{flops} ({flops_m:.2f}M)," \ " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)" logging.info(stat_msg.format( params=num_params, params_m=num_params / 1e6, flops=num_flops, flops_m=num_flops / 1e6, flops2=num_flops / 2, flops2_m=num_flops / 2 / 1e6, macs=num_macs, macs_m=num_macs / 1e6))
def calc_model_accuracy(net,
                        test_data,
                        batch_fn,
                        data_source_needs_reset,
                        metric,
                        dtype,
                        ctx,
                        input_image_size,
                        in_channels,
                        calc_weight_count=False,
                        calc_flops=False,
                        calc_flops_only=True,
                        extended_log=False):
    """
    Main test routine: evaluate accuracy and/or log model size/complexity statistics.

    Parameters:
    ----------
    net : HybridBlock
        Model.
    test_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator.
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if test_data is ImageRecordIter).
    metric : EvalMetric
        Metric object instance.
    dtype : str
        Base data type for tensors.
    ctx : list of Context
        MXNet contexts; the first one is used for model measurement.
    input_image_size : tuple of 2 ints
        Spatial size of the expected input image.
    in_channels : int
        Number of input channels.
    calc_weight_count : bool, default False
        Whether to calculate count of weights.
    calc_flops : bool, default False
        Whether to calculate FLOPs.
    calc_flops_only : bool, default True
        Whether to only calculate FLOPs without testing.
    extended_log : bool, default False
        Whether to log more precise accuracy values.

    Returns:
    -------
    list of floats
        Accuracy values taken from `metric.get()`. A single value is wrapped in a
        one-element list. The list is empty when `calc_flops_only` is True.
    """
    if not calc_flops_only:
        tic = time.time()
        validate(
            metric=metric,
            net=net,
            val_data=test_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        accuracy_msg = report_accuracy(metric=metric, extended_log=extended_log)
        logging.info("Test: {}".format(accuracy_msg))
        logging.info("Time cost: {:.4f} sec".format(time.time() - tic))
        # EvalMetric.get() returns (name(s), value(s)). Normalize the value(s) to a list.
        acc_values = metric.get()[1]
        acc_values = acc_values if isinstance(acc_values, list) else [acc_values]
    else:
        acc_values = []

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        # When FLOPs are requested, the parameter count is reported in the combined line below.
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))

    if calc_flops:
        num_flops, num_macs, num_params = measure_model(
            net, in_channels, input_image_size, ctx[0])
        # Cross-check both parameter counting methods when both were run.
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(
            stat_msg.format(
                params=num_params,
                params_m=num_params / 1e6,
                flops=num_flops,
                flops_m=num_flops / 1e6,
                flops2=num_flops / 2,
                flops2_m=num_flops / 2 / 1e6,
                macs=num_macs,
                macs_m=num_macs / 1e6))

    return acc_values