Ejemplo n.º 1
0
    def test_print_shape(self):
        """Shape/flops annotations appear in ``str(model)`` only inside ``enable()``."""
        net = create_model2()
        dummy = torch.zeros(net.OP_INPUT)
        repr_before = str(net)
        estimator = flops_utils.FlopsEstimation(net)
        repr_after = str(net)
        # Merely constructing the estimator must leave the string form untouched.
        self.assertEqual(repr_before, repr_after)

        with estimator.enable():
            net(dummy)
            estimator.add_flops_info()
            annotated = str(net)
            print(annotated)

        expected_fragments = (
            "input_shapes=[[2, 3, 16, 16]], output_shapes=[2, 8, 16, 16]",
            "input_shapes=[[2, 8, 8, 8]], output_shapes=[2, 4, 8, 8]",
            "ReLU(input_shapes=[[2, 4, 8, 8]], output_shapes=[2, 4, 8, 8], nparams=0.0, nflops=0.0)",  # noqa
            "nparams=0.00108, nflops=0.221184",
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, annotated)

        # After leaving the context, none of the extra fields may remain.
        cleaned = str(net)
        for field_name in ("input_shapes", "output_shapes", "nparams", "nflops"):
            self.assertNotIn(field_name, cleaned)
Ejemplo n.º 2
0
    def test_duplicated(self):
        """Make sure subclasses are handled properly for mock."""
        layers = [
            M1(),
            M2(),
            nn.Conv2d(3, 4, 3),
            nn.ConvTranspose2d(4, 4, 3),
            M1(),
        ]
        net = nn.Sequential(*layers)
        dummy = torch.zeros([1, 3, 4, 4])

        estimator = flops_utils.FlopsEstimation(net)
        recorded = []

        def flops_callback(fest, model, model_data):
            nparams, nflops = fest.get_flops()
            recorded.append({"nparams": nparams, "nflops": nflops})
            # Switch estimation off once three measurements were collected.
            if len(recorded) >= 3:
                fest.set_enable(False)

        estimator.set_callback(flops_callback)
        estimator.set_enable(True)

        # Five forward passes are run, but only three entries are expected.
        for _ in range(5):
            net(dummy)

        expected = [{"nparams": 0.000252, "nflops": 0.002736} for _ in range(3)]
        self.assertEqual(expected, recorded)
Ejemplo n.º 3
0
    def test_callback(self):
        """The callback records flops until it disables the estimator itself."""
        net = create_model2()
        dummy = torch.zeros(net.OP_INPUT)

        estimator = flops_utils.FlopsEstimation(net)
        recorded = []

        def flops_callback(fest, model, model_data):
            nparams, nflops = fest.get_flops()
            recorded.append({"nparams": nparams, "nflops": nflops})
            # Switch estimation off once three measurements were collected.
            if len(recorded) >= 3:
                fest.set_enable(False)

        estimator.set_callback(flops_callback)
        estimator.set_enable(True)

        # Five forward passes are run, but only three entries are expected.
        for _ in range(5):
            net(dummy)

        expected = [{"nparams": 0.00108, "nflops": 0.221184} for _ in range(3)]
        self.assertEqual(expected, recorded)
Ejemplo n.º 4
0
def print_flops(model, first_batch):
    """Log and return the flops-annotated string form of ``model``.

    The work is done on a deep copy of ``model`` that is put into eval mode,
    so the object passed in is not the one that is run.

    Args:
        model: the model to analyze; deep-copied before use.
        first_batch: passed as the single positional argument to the copy.

    Returns:
        ``str()`` of the copied model, taken after ``add_flops_info()`` while
        estimation is still enabled.
    """
    logger.info("Evaluating model's number of parameters and FLOPS")

    model_copy = copy.deepcopy(model)
    model_copy.eval()

    estimator = flops_utils.FlopsEstimation(model_copy)
    with estimator.enable():
        model_copy(first_batch)
        estimator.add_flops_info()
        annotated = str(model_copy)
        logger.info(annotated)

    return annotated
Ejemplo n.º 5
0
    def test_get_flops(self):
        """``get_flops()`` reports the expected parameter and flop counts."""
        net = create_model2()
        estimator = flops_utils.FlopsEstimation(net)

        dummy = torch.zeros(net.OP_INPUT)
        with estimator.enable():
            net(dummy)
            param_count, flop_count = estimator.get_flops()

        # Compared approximately because both values are floats.
        checks = ((param_count, 0.00108), (flop_count, 0.221184))
        for actual, expected in checks:
            self.assertAlmostEqual(actual, expected)
Ejemplo n.º 6
0
def dump_flops_info(model, inputs, output_dir):
    """
    Write flops information for ``model`` into ``output_dir``.

    Two flop counters are run on a deep copy of the model in eval mode, each
    guarded so that a failure is logged rather than raised. Only a shallow
    table is logged; the detailed results go to files. Nothing is done on
    processes other than the main one.

    Args:
        model: the model to analyze; deep-copied before use.
        inputs: passed to the model copy as a single positional argument and
            to ``FlopCountAnalysis`` unchanged.
        output_dir: directory that receives the ``flops_*.txt`` files.
    """
    if not comm.is_main_process():
        return

    logger.info("Evaluating model's number of parameters and FLOPS")
    model_copy = copy.deepcopy(model)
    model_copy.eval()

    # Counter 1: mobile_cv's FlopsEstimation, dumped as an annotated model string.
    try:
        estimator = flops_utils.FlopsEstimation(model_copy)
        with estimator.enable():
            model_copy(inputs)
            estimator.add_flops_info()
            annotated = str(model_copy)
        output_file = os.path.join(output_dir, "flops_str_mobilecv.txt")
        with PathManager.open(output_file, "w") as f:
            f.write(annotated)
            logger.info(f"Flops info written to {output_file}")
    except Exception:
        logger.exception(
            "Failed to estimate flops using mobile_cv's FlopsEstimation")

    # Counter 2: d2/fvcore's FlopCountAnalysis.
    try:
        analysis = FlopCountAnalysis(model_copy, inputs)

        # Model-string form.
        fvcore_str = flop_count_str(analysis)
        output_file = os.path.join(output_dir, "flops_str_fvcore.txt")
        with PathManager.open(output_file, "w") as f:
            f.write(fvcore_str)
            logger.info(f"Flops info written to {output_file}")

        # Deep table, written to disk only.
        deep_table = flop_count_table(analysis, max_depth=10)
        output_file = os.path.join(output_dir, "flops_table_fvcore.txt")
        with PathManager.open(output_file, "w") as f:
            f.write(deep_table)
            logger.info(f"Flops table written to {output_file}")

        # Shallow table, logged only.
        shallow_table = flop_count_table(analysis, max_depth=3)
        logger.info("Flops table:\n" + shallow_table)
    except Exception:
        logger.exception(
            "Failed to estimate flops using detectron2's FlopCountAnalysis")
Ejemplo n.º 7
0
def add_print_flops_callback(cfg, model, disable_after_callback=True):
    """Install a callback on ``model`` that logs its flops, then enable estimation.

    Args:
        cfg: config object; ``cfg.SOLVER.IMS_PER_BATCH`` is only read as a
            fallback when the batch size cannot be taken from recorded shapes.
        model: the model wrapped by ``flops_utils.FlopsEstimation``.
        disable_after_callback: when true, the callback turns estimation off
            at the end of its run.

    Returns:
        Whatever ``FlopsEstimation(model).set_callback(...)`` returns, with
        ``set_enable(True)`` already called on it.
    """

    def _print_flops_callback(self, model, model_data):
        self.add_flops_info()
        logger.info("Callback: model flops info:\n{}".format(model))

        def _guess_batch_size():
            # Inputs are meta-arch dependent; the most general solution would be
            # a function like `get_batch_size()` on each meta arch.
            try:
                model_input_shapes = model_data(model)["input_shapes"]
                assert isinstance(model_input_shapes, list)
                assert len(model_input_shapes) > 0
                # The first input is assumed to be a list of images.
                return len(model_input_shapes[0])
            except Exception:
                # Fall back to the per-process share of the configured batch.
                ret = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
                logger.warning(
                    "Could not get batch size, compute from"
                    f" `cfg.SOLVER.IMS_PER_BATCH`={ret}"
                )
                return ret

        nparams, nflops = self.get_flops()
        batch_size = _guess_batch_size()
        nflops_single = nflops / batch_size
        logger.info(
            f"Model parameters (M): {nparams}, "
            f"MFlops (batch_size={batch_size}): {nflops}, "
            f"MFlops (batch_size=1): {nflops_single}"
        )

        if not disable_after_callback:
            return
        self.set_enable(False)

    estimator = flops_utils.FlopsEstimation(model)
    estimator = estimator.set_callback(_print_flops_callback)
    logger.info("Added callback to log flops info after the first inference")
    estimator.set_enable(True)
    return estimator
Ejemplo n.º 8
0
def dump_flops_info(model, inputs, output_dir, use_eval_mode=True):
    """
    Dump flops information about model, using the given model inputs.
    Information are dumped to output_dir using various flop counting tools
    in different formats. Only a simple table is printed to terminal.

    Args:
        model: the model to analyze. It is deep-copied first; if the copy
            fails, the function logs a message and returns early.
        inputs: a tuple of positional arguments used to call model with.
            It is deep-copied before use.
        output_dir: directory that receives the ``flops_*.txt`` files.
        use_eval_mode: turn the model into eval mode for flop counting. Otherwise,
            will use the original mode. It's recommended to use eval mode, because
            training mode typically follows a different codepath.

    Returns:
        The ``FlopCountAnalysis`` object when the fvcore counter succeeds,
        ``float("nan")`` when it fails, and ``None`` when called on a
        non-main process or when the model cannot be deep-copied.
    """
    if not comm.is_main_process():
        return
    logger.info("Evaluating model's number of parameters and FLOPS")

    try:
        model = copy.deepcopy(model)
    except Exception:
        logger.info("Failed to deepcopy the model and skip FlopsEstimation.")
        return

    # Delete other forward_pre_hooks so they are not simultaneously called.
    # `clear()` is used instead of deleting keys inside a `for` loop over the
    # same mapping: mutating a dict while iterating it raises RuntimeError as
    # soon as at least one hook is registered.
    model._forward_pre_hooks.clear()

    if use_eval_mode:
        model.eval()
    inputs = copy.deepcopy(inputs)

    # 1. using mobile_cv flop counter
    try:
        fest = flops_utils.FlopsEstimation(model)
        with fest.enable():
            model(*inputs)
            fest.add_flops_info()
            model_str = str(model)
        output_file = os.path.join(output_dir, "flops_str_mobilecv.txt")
        with PathManager.open(output_file, "w") as f:
            f.write(model_str)
            logger.info(f"Flops info written to {output_file}")
    except Exception:
        logger.exception(
            "Failed to estimate flops using mobile_cv's FlopsEstimation")

    # 2. using d2/fvcore's flop counter
    # `output_file` is set before the `try` so the error handler below always
    # has a path to write the traceback to.
    output_file = os.path.join(output_dir, "flops_str_fvcore.txt")
    try:
        flops = FlopCountAnalysis(model, inputs)

        # 2.1: dump as model str
        model_str = flop_count_str(flops)
        with PathManager.open(output_file, "w") as f:
            f.write(model_str)
            logger.info(f"Flops info written to {output_file}")

        # 2.2: dump as table
        flops_table = flop_count_table(flops, max_depth=10)
        output_file = os.path.join(output_dir, "flops_table_fvcore.txt")
        with PathManager.open(output_file, "w") as f:
            f.write(flops_table)
            logger.info(f"Flops table (full version) written to {output_file}")

        # 2.3: print a table with a shallow depth
        flops_table = flop_count_table(flops, max_depth=3)
        logger.info("Flops table:\n" + flops_table)
    except Exception:
        # The traceback goes to whichever fvcore output file was current when
        # the failure happened; only a short warning is logged.
        with PathManager.open(output_file, "w") as f:
            traceback.print_exc(file=f)
        logger.warning(
            "Failed to estimate flops using detectron2's FlopCountAnalysis. "
            f"Error written to {output_file}.")
        flops = float("nan")
    return flops