コード例 #1
0
def test(net,
         val_data,
         batch_fn,
         data_source_needs_reset,
         dtype,
         ctx,
         calc_weight_count=False,
         extended_log=False):
    """
    Run validation with an RMSE metric and log the outcome.

    Parameters:
    ----------
    net
        Model; forwarded to `validate` and, if requested, to `calc_net_weight_count`.
    val_data
        Validation data source, forwarded to `validate`.
    batch_fn
        Forwarded to `validate` unchanged.
    data_source_needs_reset
        Forwarded to `validate` unchanged.
    dtype
        Forwarded to `validate` unchanged.
    ctx
        Forwarded to `validate` unchanged.
    calc_weight_count : bool, default False
        Whether to log the value returned by `calc_net_weight_count(net)`.
    extended_log : bool, default False
        Whether to log the unrounded RMSE next to the 4-digit value.
    """
    metric = mx.metric.RMSE()

    # The timer starts before validation and is read at the very end, so the
    # reported time also covers the optional weight counting.
    started_at = time.time()
    rmse = validate(
        metric_calc=metric,
        net=net,
        val_data=val_data,
        batch_fn=batch_fn,
        data_source_needs_reset=data_source_needs_reset,
        dtype=dtype,
        ctx=ctx)

    if calc_weight_count:
        n_weights = calc_net_weight_count(net)
        logging.info('Model: {} trainable parameters'.format(n_weights))

    # Pick the message layout once, then format it with the measured value.
    if extended_log:
        template = 'Test: rmse={rmse:.4f} ({rmse})'
    else:
        template = 'Test: rmse={rmse:.4f}'
    logging.info(template.format(rmse=rmse))

    elapsed = time.time() - started_at
    logging.info('Time cost: {:.4f} sec'.format(elapsed))
コード例 #2
0
ファイル: eval_gl.py プロジェクト: jdc08161063/imgclsmob
def test(net,
         val_data,
         batch_fn,
         use_rec,
         dtype,
         ctx,
         calc_weight_count=False,
         extended_log=False):
    """
    Run validation with top-1 / top-5 accuracy metrics and log the errors.

    Parameters:
    ----------
    net
        Model; forwarded to `validate` and, if requested, to `calc_net_weight_count`.
    val_data
        Validation data source, forwarded to `validate`.
    batch_fn
        Forwarded to `validate` unchanged.
    use_rec
        Forwarded to `validate` unchanged.
    dtype
        Forwarded to `validate` unchanged.
    ctx
        Forwarded to `validate` unchanged.
    calc_weight_count : bool, default False
        Whether to log the value returned by `calc_net_weight_count(net)`.
    extended_log : bool, default False
        Whether to log the unrounded error values next to the 4-digit ones.
    """
    top1_metric = mx.metric.Accuracy()
    top5_metric = mx.metric.TopKAccuracy(5)

    # The timer starts before validation and is read at the very end, so the
    # reported time also covers the optional weight counting.
    started_at = time.time()
    top1_err, top5_err = validate(
        acc_top1=top1_metric,
        acc_top5=top5_metric,
        net=net,
        val_data=val_data,
        batch_fn=batch_fn,
        use_rec=use_rec,
        dtype=dtype,
        ctx=ctx)

    if calc_weight_count:
        n_weights = calc_net_weight_count(net)
        logging.info('Model: {} trainable parameters'.format(n_weights))

    # Pick the message layout once, then format it with both error values.
    if extended_log:
        template = 'Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})'
    else:
        template = 'Test: err-top1={top1:.4f}\terr-top5={top5:.4f}'
    logging.info(template.format(top1=top1_err, top5=top5_err))

    elapsed = time.time() - started_at
    logging.info('Time cost: {:.4f} sec'.format(elapsed))
コード例 #3
0
def test(net,
         val_data,
         batch_fn,
         data_source_needs_reset,
         dtype,
         ctx,
         input_image_size,
         in_channels,
         calc_weight_count=False,
         calc_flops=False,
         calc_flops_only=True,
         extended_log=False):
    """
    Optionally validate a model (top-1 / top-5 errors) and log its size statistics.

    Parameters:
    ----------
    net
        Model; forwarded to `validate`, `calc_net_weight_count` and `measure_model`.
    val_data
        Validation data source, forwarded to `validate`.
    batch_fn
        Forwarded to `validate` unchanged.
    data_source_needs_reset
        Forwarded to `validate` unchanged.
    dtype
        Forwarded to `validate` unchanged.
    ctx
        Forwarded to `validate`; its first element (`ctx[0]`) is handed to
        `measure_model`, so it must be indexable when `calc_flops` is set.
    input_image_size
        Forwarded to `measure_model`.
    in_channels
        Forwarded to `measure_model`.
    calc_weight_count : bool, default False
        Whether to call `calc_net_weight_count(net)`. The count is logged on
        its own only when `calc_flops` is off; otherwise it is just
        cross-checked against the parameter count from `measure_model`.
    calc_flops : bool, default False
        Whether to call `measure_model` and log Params/FLOPs/MACs.
    calc_flops_only : bool, default True
        When true, validation is skipped entirely.
    extended_log : bool, default False
        Whether to log the unrounded error values next to the 4-digit ones.
    """
    if not calc_flops_only:
        top1_metric = mx.metric.Accuracy()
        top5_metric = mx.metric.TopKAccuracy(5)
        started_at = time.time()
        top1_err, top5_err = validate(
            acc_top1=top1_metric,
            acc_top5=top5_metric,
            net=net,
            val_data=val_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        if extended_log:
            template = 'Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})'
        else:
            template = 'Test: err-top1={top1:.4f}\terr-top5={top5:.4f}'
        logging.info(template.format(top1=top1_err, top5=top5_err))
        logging.info('Time cost: {:.4f} sec'.format(time.time() - started_at))

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))

    if not calc_flops:
        return

    num_flops, num_macs, num_params = measure_model(
        net, in_channels, input_image_size, ctx[0])
    if calc_weight_count:
        # Both counting routes must agree on the number of parameters.
        assert weight_count == num_params

    half_flops = num_flops / 2
    stats = {
        "params": num_params,
        "params_m": num_params / 1e6,
        "flops": num_flops,
        "flops_m": num_flops / 1e6,
        "flops2": half_flops,
        "flops2_m": half_flops / 1e6,
        "macs": num_macs,
        "macs_m": num_macs / 1e6,
    }
    stat_msg = ("Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M),"
                " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)")
    logging.info(stat_msg.format(**stats))
コード例 #4
0
def test(net,
         test_data,
         batch_fn,
         data_source_needs_reset,
         metric,
         dtype,
         ctx,
         input_image_size,
         in_channels,
         calc_weight_count=False,
         calc_flops=False,
         calc_flops_only=True,
         extended_log=False):
    """
    Optionally validate a model with the given metric and log its size statistics.

    Parameters:
    ----------
    net
        Model; forwarded to `validate`, `calc_net_weight_count` and `measure_model`.
    test_data
        Data source, passed to `validate` as `val_data`.
    batch_fn
        Forwarded to `validate` unchanged.
    data_source_needs_reset
        Forwarded to `validate` unchanged.
    metric
        Metric object handed to `validate` and afterwards to `report_accuracy`.
    dtype
        Forwarded to `validate` unchanged.
    ctx
        Forwarded to `validate`; its first element (`ctx[0]`) is handed to
        `measure_model`, so it must be indexable when `calc_flops` is set.
    input_image_size
        Forwarded to `measure_model`.
    in_channels
        Forwarded to `measure_model`.
    calc_weight_count : bool, default False
        Whether to call `calc_net_weight_count(net)`. The count is logged on
        its own only when `calc_flops` is off; otherwise it is just
        cross-checked against the parameter count from `measure_model`.
    calc_flops : bool, default False
        Whether to call `measure_model` and log Params/FLOPs/MACs.
    calc_flops_only : bool, default True
        When true, validation is skipped entirely.
    extended_log : bool, default False
        Forwarded to `report_accuracy`.
    """
    if not calc_flops_only:
        started_at = time.time()
        validate(
            metric=metric,
            net=net,
            val_data=test_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)
        summary = report_accuracy(metric=metric, extended_log=extended_log)
        logging.info("Test: {}".format(summary))
        logging.info("Time cost: {:.4f} sec".format(time.time() - started_at))

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))

    if not calc_flops:
        return

    num_flops, num_macs, num_params = measure_model(
        net, in_channels, input_image_size, ctx[0])
    if calc_weight_count:
        # Both counting routes must agree on the number of parameters.
        assert weight_count == num_params

    half_flops = num_flops / 2
    stats = {
        "params": num_params,
        "params_m": num_params / 1e6,
        "flops": num_flops,
        "flops_m": num_flops / 1e6,
        "flops2": half_flops,
        "flops2_m": half_flops / 1e6,
        "macs": num_macs,
        "macs_m": num_macs / 1e6,
    }
    stat_msg = ("Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M),"
                " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)")
    logging.info(stat_msg.format(**stats))
コード例 #5
0
def test(net,
         test_data,
         data_source_needs_reset,
         dtype,
         ctx,
         input_image_size,
         in_channels,
         classes,
         calc_weight_count=False,
         calc_flops=False,
         calc_flops_only=True,
         extended_log=False,
         dataset_metainfo=None):
    """
    Optionally validate a model with pixel-accuracy / mean-IoU metrics and log its size statistics.

    Parameters:
    ----------
    net
        Model; forwarded to `validate1`, `calc_net_weight_count` and `measure_model`.
    test_data
        Data source, passed to `validate1` as `val_data`.
    data_source_needs_reset
        Forwarded to `validate1` unchanged.
    dtype
        Forwarded to `validate1` unchanged.
    ctx
        Forwarded to `validate1`; its first element (`ctx[0]`) is handed to
        `measure_model`, so it must be indexable when `calc_flops` is set.
    input_image_size
        Forwarded to `measure_model`.
    in_channels
        Forwarded to `measure_model`.
    classes
        Passed to `MeanIoUMetric` as `num_classes`.
    calc_weight_count : bool, default False
        Whether to call `calc_net_weight_count(net)`. The count is logged on
        its own only when `calc_flops` is off; otherwise it is just
        cross-checked against the parameter count from `measure_model`.
    calc_flops : bool, default False
        Whether to call `measure_model` and log Params/FLOPs/MACs.
    calc_flops_only : bool, default True
        When true, validation is skipped entirely.
    extended_log : bool, default False
        Whether to log the unrounded metric values next to the 4-digit ones.
    dataset_metainfo : mapping, default None
        Must not be None (asserted). The keys "vague_idx", "use_vague",
        "background_idx" and "ignore_bg" are read when validation runs.
    """
    assert (dataset_metainfo is not None)

    if not calc_flops_only:
        # Both metrics are configured without macro averaging.
        pix_acc_macro_average = False
        mean_iou_macro_average = False

        vague_idx = dataset_metainfo["vague_idx"]
        use_vague = dataset_metainfo["use_vague"]

        metric = mx.metric.CompositeEvalMetric()
        metric.add(PixelAccuracyMetric(
            vague_idx=vague_idx,
            use_vague=use_vague,
            macro_average=pix_acc_macro_average))
        metric.add(MeanIoUMetric(
            num_classes=classes,
            vague_idx=vague_idx,
            use_vague=use_vague,
            bg_idx=dataset_metainfo["background_idx"],
            ignore_bg=dataset_metainfo["ignore_bg"],
            macro_average=mean_iou_macro_average))

        started_at = time.time()
        # NOTE(review): `batch_fn` is not a parameter of this function; it has
        # to resolve to a module-level name -- confirm it exists in the file.
        accuracy_info = validate1(
            accuracy_metric=metric,
            net=net,
            val_data=test_data,
            batch_fn=batch_fn,
            data_source_needs_reset=data_source_needs_reset,
            dtype=dtype,
            ctx=ctx)

        # The second element holds the metric values, in the order the
        # metrics were added above.
        metric_values = accuracy_info[1]
        pix_acc = metric_values[0]
        mean_iou = metric_values[1]

        fields = {
            "pix_macro": "macro" if pix_acc_macro_average else "micro",
            "pix_acc": pix_acc,
            "iou_macro": "macro" if mean_iou_macro_average else "micro",
            "mean_iou": mean_iou,
        }
        if extended_log:
            template = ("Test: {pix_macro}-pix_acc={pix_acc:.4f} ({pix_acc}), "
                        "{iou_macro}-mean_iou={mean_iou:.4f} ({mean_iou})")
        else:
            template = "Test: {pix_macro}-pix_acc={pix_acc:.4f}, {iou_macro}-mean_iou={mean_iou:.4f}"
        logging.info(template.format(**fields))
        logging.info("Time cost: {:.4f} sec".format(time.time() - started_at))

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))

    if not calc_flops:
        return

    num_flops, num_macs, num_params = measure_model(net, in_channels, input_image_size, ctx[0])
    if calc_weight_count:
        # Both counting routes must agree on the number of parameters.
        assert weight_count == num_params

    half_flops = num_flops / 2
    stats = {
        "params": num_params,
        "params_m": num_params / 1e6,
        "flops": num_flops,
        "flops_m": num_flops / 1e6,
        "flops2": half_flops,
        "flops2_m": half_flops / 1e6,
        "macs": num_macs,
        "macs_m": num_macs / 1e6,
    }
    stat_msg = ("Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M),"
                " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)")
    logging.info(stat_msg.format(**stats))
コード例 #6
0
ファイル: eval_gl.py プロジェクト: siddie/imgclsmob
def calc_model_accuracy(net,
                        test_data,
                        batch_fn,
                        data_source_needs_reset,
                        metric,
                        dtype,
                        ctx,
                        input_image_size,
                        in_channels,
                        calc_weight_count=False,
                        calc_flops=False,
                        calc_flops_only=True,
                        extended_log=False):
    """
    Main test routine.

    Parameters:
    ----------
    net : HybridBlock
        Model.
    test_data : DataLoader or ImageRecordIter
        Data loader or ImRec-iterator.
    batch_fn : func
        Function for splitting data after extraction from data loader.
    data_source_needs_reset : bool
        Whether to reset data (if test_data is ImageRecordIter).
    metric : EvalMetric
        Metric object instance.
    dtype : str
        Base data type for tensors.
    ctx : sequence of Context
        MXNet contexts. Passed as a whole to `validate`; only the first one
        (`ctx[0]`) is used for the FLOPs measurement.
    input_image_size : tuple of 2 ints
        Spatial size of the expected input image.
    in_channels : int
        Number of input channels.
    calc_weight_count : bool, default False
        Whether to calculate count of weights.
    calc_flops : bool, default False
        Whether to calculate FLOPs.
    calc_flops_only : bool, default True
        Whether to only calculate FLOPs without testing.
    extended_log : bool, default False
        Whether to log more precise accuracy values.

    Returns:
    -------
    list of floats
        Accuracy values; an empty list when testing is skipped
        (`calc_flops_only` is true).
    """
    if not calc_flops_only:
        tic = time.time()
        validate(metric=metric,
                 net=net,
                 val_data=test_data,
                 batch_fn=batch_fn,
                 data_source_needs_reset=data_source_needs_reset,
                 dtype=dtype,
                 ctx=ctx)
        accuracy_msg = report_accuracy(metric=metric,
                                       extended_log=extended_log)
        logging.info("Test: {}".format(accuracy_msg))
        logging.info("Time cost: {:.4f} sec".format(time.time() - tic))
        acc_values = metric.get()[1]
        # A single metric yields a bare value; normalise it to a list so that
        # callers always receive the same shape. `isinstance` (rather than an
        # exact type comparison) also accepts list subclasses as-is.
        if not isinstance(acc_values, list):
            acc_values = [acc_values]
    else:
        acc_values = []

    if calc_weight_count:
        weight_count = calc_net_weight_count(net)
        # When FLOPs are measured too, the parameter count is reported as
        # part of the combined statistics line below instead.
        if not calc_flops:
            logging.info("Model: {} trainable parameters".format(weight_count))
    if calc_flops:
        num_flops, num_macs, num_params = measure_model(
            net, in_channels, input_image_size, ctx[0])
        # Sanity check: both counting routes must agree on the parameter count.
        assert (not calc_weight_count) or (weight_count == num_params)
        stat_msg = "Params: {params} ({params_m:.2f}M), FLOPs: {flops} ({flops_m:.2f}M)," \
                   " FLOPs/2: {flops2} ({flops2_m:.2f}M), MACs: {macs} ({macs_m:.2f}M)"
        logging.info(
            stat_msg.format(params=num_params,
                            params_m=num_params / 1e6,
                            flops=num_flops,
                            flops_m=num_flops / 1e6,
                            flops2=num_flops / 2,
                            flops2_m=num_flops / 2 / 1e6,
                            macs=num_macs,
                            macs_m=num_macs / 1e6))

    return acc_values