コード例 #1
0
def dump_flops_info(model, inputs, output_dir):
    """
    Write FLOP statistics of ``model`` evaluated on ``inputs`` into ``output_dir``.

    Two counters are tried independently: mobile_cv's ``FlopsEstimation`` and
    fvcore's ``FlopCountAnalysis``. A failure in one is logged with its
    traceback and does not prevent the other from running. Several text files
    are written under ``output_dir``. Only a shallow (``max_depth=3``) fvcore
    table is sent to the logger. Non-main processes return immediately.

    The model is deep-copied and switched to eval mode, so the caller's model
    is not modified.
    """

    def _write_text(path, text, message):
        # Write ``text`` to ``path`` and log ``message`` while the file is still
        # open (same ordering as writing and logging inline).
        with PathManager.open(path, "w") as handle:
            handle.write(text)
            logger.info(message)

    if not comm.is_main_process():
        return
    logger.info("Evaluating model's number of parameters and FLOPS")
    analysed = copy.deepcopy(model)
    analysed.eval()

    # Counter 1: mobile_cv. The model is run once with the estimator enabled.
    # ``add_flops_info`` is then expected to attach counts to the module repr,
    # which is what gets dumped.
    try:
        estimator = flops_utils.FlopsEstimation(analysed)
        with estimator.enable():
            analysed(inputs)
            estimator.add_flops_info()
            annotated_repr = str(analysed)
        mobilecv_path = os.path.join(output_dir, "flops_str_mobilecv.txt")
        _write_text(
            mobilecv_path, annotated_repr, f"Flops info written to {mobilecv_path}"
        )
    except Exception:
        logger.exception(
            "Failed to estimate flops using mobile_cv's FlopsEstimation")

    # Counter 2: fvcore / detectron2, dumped as a module string, a deep table,
    # and a shallow table that is only logged.
    try:
        analysis = FlopCountAnalysis(analysed, inputs)

        fvcore_str = flop_count_str(analysis)
        str_path = os.path.join(output_dir, "flops_str_fvcore.txt")
        _write_text(str_path, fvcore_str, f"Flops info written to {str_path}")

        full_table = flop_count_table(analysis, max_depth=10)
        table_path = os.path.join(output_dir, "flops_table_fvcore.txt")
        _write_text(table_path, full_table, f"Flops table written to {table_path}")

        shallow_table = flop_count_table(analysis, max_depth=3)
        logger.info("Flops table:\n" + shallow_table)
    except Exception:
        logger.exception(
            "Failed to estimate flops using detectron2's FlopCountAnalysis")
コード例 #2
0
def do_flop(cfg, num_inputs=None):
    """
    Estimate the FLOPs of the model described by ``cfg`` on test-set samples.

    The test data loader and model are built from ``cfg``, which is either a
    yacs ``CfgNode`` or a LazyConfig-style object. The configured weights are
    loaded, and fvcore's ``FlopCountAnalysis`` is run on up to ``num_inputs``
    samples.

    Three things are logged:
    - the per-module flop table of the last analysed sample;
    - the average GFlops per operator type;
    - the mean ± std of total GFlops across samples.

    Args:
        cfg: a ``CfgNode`` or a LazyConfig object.
        num_inputs: maximum number of data-loader samples to analyse. When
            ``None`` (the default), falls back to the script-level
            ``args.num_inputs``, which was the original behavior.

    Returns:
        None. If no sample was analysed (empty loader or ``num_inputs <= 0``),
        a warning is logged and nothing else is reported.
    """
    if isinstance(cfg, CfgNode):
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    else:
        data_loader = instantiate(cfg.dataloader.test)
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
    model.eval()

    if num_inputs is None:
        # Backward-compatible default: the module-level CLI namespace.
        num_inputs = args.num_inputs  # noqa

    counts = Counter()
    total_flops = []
    flops = None
    for idx, data in zip(tqdm.trange(num_inputs), data_loader):
        flops = FlopCountAnalysis(model, data)
        if idx > 0:
            # Let fvcore emit unsupported-op / uncalled-module warnings only
            # for the first sample, so they are not repeated for every input.
            flops.unsupported_ops_warnings(False).uncalled_modules_warnings(
                False)
        counts += flops.by_operator()
        total_flops.append(flops.total())

    if flops is None:
        # Nothing was analysed. Without this guard, the report below would
        # raise UnboundLocalError on ``flops``.
        logger.warning("No inputs were analysed; skipping flop report.")
        return

    num_samples = len(total_flops)
    logger.info("Flops table computed from only one input sample:\n" +
                flop_count_table(flops))
    logger.info("Average GFlops for each type of operators:\n" +
                str([(k, v / num_samples / 1e9) for k, v in counts.items()]))
    logger.info("Total GFlops: {:.1f}±{:.1f}".format(
        np.mean(total_flops) / 1e9,
        np.std(total_flops) / 1e9))
コード例 #3
0
 def test_detr_fbnet_export(self):
     """Build DETR with an FBNetV2 backbone, check it against its TorchScript version, and print flops."""
     runner = create_runner("d2go.projects.detr.runner.DETRRunner")
     cfg = runner.get_default_cfg()
     cfg.MODEL.DEVICE = "cpu"

     # DETR settings; the positional values are forwarded to the helper as-is.
     self._set_detr_cfg(cfg, 3, 3, 50, 256)

     # FBNetV2 backbone settings.
     cfg.MODEL.BACKBONE.NAME = "FBNetV2C4Backbone"
     fbnet_overrides = {
         "ARCH": "FBNetV3_A_dsmask_C5",
         "WIDTH_DIVISOR": 8,
         "OUT_FEATURES": ["trunk4"],
     }
     for key, value in fbnet_overrides.items():
         setattr(cfg.MODEL.FBNET_V2, key, value)

     # Only the inner ``detr`` submodule is scripted and checked with
     # ``_assert_model_output`` against the eager module.
     detr = runner.build_model(cfg).eval().detr
     print(detr)
     self._assert_model_output(detr, torch.jit.script(detr))

     # Flop report for a single random 3x224x320 tensor, wrapped in a list
     # and passed as the sole positional argument.
     flop_inputs = ([torch.rand(3, 224, 320)], )
     print(flop_count_table(FlopCountAnalysis(detr, flop_inputs)))
コード例 #4
0
def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
    """
    Dump flops information about model, using the given model inputs.

    Information is dumped to ``output_dir`` using various flop counting tools
    in different formats. Only a simple table is printed to the terminal.

    Args:
        model: the model to analyse. It is deep-copied first, so the caller's
            model (including its hooks) is not modified.
        inputs: a tuple of positional arguments used to call model with.
            It is deep-copied before use.
        output_dir: directory that receives the dump files.
        use_eval_mode: turn the model into eval mode for flop counting. Otherwise,
            will use the original mode. It's recommended to use eval mode, because
            training mode typically follows a different codepath.

    Returns:
        The fvcore ``FlopCountAnalysis`` object on success, or ``float("nan")``
        if the fvcore analysis failed. Returns ``None`` when this is not the
        main process or when the model cannot be deep-copied.
    """
    if not comm.is_main_process():
        return
    logger.info("Evaluating model's number of parameters and FLOPS")

    try:
        model = copy.deepcopy(model)
    except Exception:
        logger.info("Failed to deepcopy the model and skip FlopsEstimation.")
        return

    # Delete the copy's top-level forward_pre_hooks so they are not
    # simultaneously called during flop counting. The hooks must be cleared in
    # one call: deleting keys while iterating the same dict raises RuntimeError
    # ("mutated during iteration") as soon as any hook is present.
    model._forward_pre_hooks.clear()

    if use_eval_mode:
        model.eval()
    inputs = copy.deepcopy(inputs)

    # 1. using mobile_cv flop counter
    try:
        fest = flops_utils.FlopsEstimation(model)
        with fest.enable():
            model(*inputs)
            fest.add_flops_info()
            model_str = str(model)
        output_file = os.path.join(output_dir, "flops_str_mobilecv.txt")
        with PathManager.open(output_file, "w") as f:
            f.write(model_str)
            logger.info(f"Flops info written to {output_file}")
    except Exception:
        logger.exception(
            "Failed to estimate flops using mobile_cv's FlopsEstimation")

    # 2. using d2/fvcore's flop counter
    # On failure, the traceback goes to whichever file was being produced
    # when the error occurred (``output_file`` is reassigned in step 2.2).
    output_file = os.path.join(output_dir, "flops_str_fvcore.txt")
    try:
        flops = FlopCountAnalysis(model, inputs)

        # 2.1: dump as model str
        model_str = flop_count_str(flops)
        with PathManager.open(output_file, "w") as f:
            f.write(model_str)
            logger.info(f"Flops info written to {output_file}")

        # 2.2: dump as table
        flops_table = flop_count_table(flops, max_depth=10)
        output_file = os.path.join(output_dir, "flops_table_fvcore.txt")
        with PathManager.open(output_file, "w") as f:
            f.write(flops_table)
            logger.info(f"Flops table (full version) written to {output_file}")

        # 2.3: print a table with a shallow depth
        flops_table = flop_count_table(flops, max_depth=3)
        logger.info("Flops table:\n" + flops_table)
    except Exception:
        with PathManager.open(output_file, "w") as f:
            traceback.print_exc(file=f)
        logger.warning(
            "Failed to estimate flops using detectron2's FlopCountAnalysis. "
            f"Error written to {output_file}.")
        flops = float("nan")
    return flops