Example #1
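The four examples below are successive versions of the operator-counting helpers in detectron2 (detectron2/utils/analysis.py) and rely on module-level imports and constants that are not shown in the snippets themselves. The shared context is sketched here for completeness: the imports are the ones these functions use, while the contents of _IGNORED_OPS are an illustrative assumption (the real set in detectron2 is larger). The helper _flatten_to_tuple, referenced in Example #4, is sketched after that example.

import logging
import typing

import torch
from torch import nn

from detectron2.export import TracingAdapter
from detectron2.utils.logger import log_first_n
from fvcore.nn import activation_count, flop_count

FLOPS_MODE = "flops"
ACTIVATIONS_MODE = "activations"

# Ops whose cost is negligible for detection models and therefore ignored
# during counting (illustrative subset; an assumption, not the exact list).
_IGNORED_OPS = {
    "aten::batch_norm",
    "aten::max_pool2d",
    "aten::relu",
    "aten::relu_",
    "aten::sigmoid",
    "aten::softmax",
    "torchvision::nms",
}
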
def _wrapper_count_operators(model: nn.Module, inputs: list, mode: str,
                             **kwargs) -> typing.DefaultDict[str, float]:
    # ignore some ops
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]
    inputs = [{
        "image": tensor_input
    }]  # remove other keys, in case there are any

    old_train = model.training
    if isinstance(model, (nn.parallel.distributed.DistributedDataParallel,
                          nn.DataParallel)):
        model = model.module
    wrapper = TracingAdapter(model, inputs)
    wrapper.eval()
    if mode == FLOPS_MODE:
        ret = flop_count(wrapper, (tensor_input, ), **kwargs)
    elif mode == ACTIVATIONS_MODE:
        ret = activation_count(wrapper, (tensor_input, ), **kwargs)
    else:
        raise NotImplementedError(
            "Count for mode {} is not supported yet.".format(mode))
    # compatible with a change in fvcore: newer versions return (counts, skipped ops)
    if isinstance(ret, tuple):
        ret = ret[0]
    model.train(old_train)
    return ret
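
For context, TracingAdapter (imported from detectron2.export) is detectron2's shim that flattens the usual list-of-dict inputs and structured outputs into plain tensors so the model can be traced. A rough standalone sketch, following detectron2's deployment documentation; the flattened_inputs attribute comes from those docs and should be treated as an assumption here, with model and inputs reusing the names from the function above:

wrapper = TracingAdapter(model, inputs)                      # inputs: list[dict] in detectron2 format
traced = torch.jit.trace(wrapper, wrapper.flattened_inputs)  # traced model takes/returns flat tensors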
Example #2
def _wrapper_count_operators(model: nn.Module, inputs: list, mode: str,
                             **kwargs) -> typing.DefaultDict[str, float]:

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]

    class WrapModel(nn.Module):
        def __init__(self, model):
            super().__init__()
            if isinstance(model,
                          (nn.parallel.distributed.DistributedDataParallel,
                           nn.DataParallel)):
                self.model = model.module
            else:
                self.model = model

        def forward(self, image):
            # jit requires the input/output to be Tensors
            inputs = [{"image": image}]
            outputs = self.model.forward(inputs)[0]
            if isinstance(outputs, dict) and "instances" in outputs:
                # Only the subgraph that computes the returned tensor will be
                # counted. So we return everything we found in Instances.
                inst = outputs["instances"]
                ret = [inst.pred_boxes.tensor]
                inst.remove("pred_boxes")
                for k, v in inst.get_fields().items():
                    if isinstance(v, torch.Tensor):
                        ret.append(v)
                    else:
                        log_first_n(
                            logging.WARN,
                            f"Field '{k}' in output instances is not included"
                            " in flops/activations count.",
                            n=10,
                        )
                return tuple(ret)
            raise NotImplementedError(
                "Count for segmentation models is not supported yet.")

    old_train = model.training
    with torch.no_grad():
        if mode == FLOPS_MODE:
            ret = flop_count(
                WrapModel(model).train(False), (tensor_input, ), **kwargs)
        elif mode == ACTIVATIONS_MODE:
            ret = activation_count(
                WrapModel(model).train(False), (tensor_input, ), **kwargs)
        else:
            raise NotImplementedError(
                "Count for mode {} is not supported yet.".format(mode))
    # compatible with a change in fvcore: newer versions return (counts, skipped ops)
    if isinstance(ret, tuple):
        ret = ret[0]
    model.train(old_train)
    return ret
Example #3
def flop_count_operators(model: nn.Module, inputs: list,
                         **kwargs) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level flops counting using jit.
    This is a wrapper of fvcore.nn.flop_count that supports standard detection models
    in detectron2.

    Note:
        The function runs the input through the model to compute flops.
        The flops of a detection model are often input-dependent; for example,
        the flops of the box & mask heads depend on the number of proposals and
        the number of detected objects.
        Therefore, flop counting with a single input may not accurately
        reflect the computation cost of a model.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
    """
    assert len(inputs) == 1, "Please only compute flops with batch size=1"
    tensor_input = inputs[0]["image"]

    class WrapModel(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, image):
            # jit requires the input/output to be Tensors
            inputs = [{"image": image}]
            outputs = self.model.forward(inputs)[0]
            if isinstance(outputs, dict) and "instances" in outputs:
                # Only the subgraph that computes the returned tensor will be
                # counted. So we return everything we found in Instances.
                inst = outputs["instances"]
                ret = [inst.pred_boxes.tensor]
                inst.remove("pred_boxes")
                for k, v in inst.get_fields().items():
                    if isinstance(v, torch.Tensor):
                        ret.append(v)
                    else:
                        log_first_n(
                            logging.WARN,
                            f"Field '{k}' in output instances is not included in flops counting.",
                            n=10,
                        )
                return tuple(ret)
            raise NotImplementedError(
                "Flops for segmentation models not supported yet.")

    old_train = model.training
    with torch.no_grad():
        ret = flop_count(
            WrapModel(model).train(False), (tensor_input, ), **kwargs)
    model.train(old_train)
    return ret
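
A minimal usage sketch for flop_count_operators, assuming a detectron2 config and the standard test dataloader; the config path is illustrative and not part of the original code:

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader
from detectron2.modeling import build_model

cfg = get_cfg()
cfg.merge_from_file("configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml")  # illustrative path
model = build_model(cfg)
DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)  # optional: load trained weights

# build_detection_test_loader yields list[dict] batches of size 1 by default
data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
inputs = next(iter(data_loader))

counts = flop_count_operators(model, inputs)
# With the fvcore version this snippet targets, counts maps operator name -> GFlops;
# newer fvcore returns a (counts, skipped ops) tuple, which Examples #1, #2 and #4 unwrap.
print("total GFlops:", sum(counts.values()))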
Example #4
def _wrapper_count_operators(model: nn.Module, inputs: list, mode: str,
                             **kwargs) -> typing.DefaultDict[str, float]:

    # ignore some ops
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]

    class WrapModel(nn.Module):
        def __init__(self, model):
            super().__init__()
            if isinstance(model,
                          (nn.parallel.distributed.DistributedDataParallel,
                           nn.DataParallel)):
                self.model = model.module
            else:
                self.model = model

        def forward(self, image):
            # jit requires the input/output to be Tensors
            inputs = [{"image": image}]
            outputs = self.model.forward(inputs)
            # Only the subgraph that computes the returned tuple of tensor will be
            # counted. So we flatten everything we found to tuple of tensors.
            return _flatten_to_tuple(outputs)

    old_train = model.training
    with torch.no_grad():
        if mode == FLOPS_MODE:
            ret = flop_count(
                WrapModel(model).train(False), (tensor_input, ), **kwargs)
        elif mode == ACTIVATIONS_MODE:
            ret = activation_count(
                WrapModel(model).train(False), (tensor_input, ), **kwargs)
        else:
            raise NotImplementedError(
                "Count for mode {} is not supported yet.".format(mode))
    # compatible with a change in fvcore: newer versions return (counts, skipped ops)
    if isinstance(ret, tuple):
        ret = ret[0]
    model.train(old_train)
    return ret
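
The helper _flatten_to_tuple is defined elsewhere in the same module and is not shown above. A minimal sketch of such a flattener, assuming detectron2's Boxes and Instances containers; the real helper may handle more types:

from detectron2.structures import Boxes, Instances

def _flatten_to_tuple(outputs):
    # Recursively collect every tensor reachable from the model outputs,
    # so the traced (and therefore counted) subgraph covers the whole computation.
    result = []
    if isinstance(outputs, torch.Tensor):
        result = [outputs]
    elif isinstance(outputs, (list, tuple)):
        for v in outputs:
            result.extend(_flatten_to_tuple(v))
    elif isinstance(outputs, dict):
        for v in outputs.values():
            result.extend(_flatten_to_tuple(v))
    elif isinstance(outputs, Instances):
        result.extend(_flatten_to_tuple(outputs.get_fields()))
    elif isinstance(outputs, Boxes):
        result = [outputs.tensor]
    else:
        log_first_n(
            logging.WARN,
            f"Output of type {type(outputs)} not counted in flops/activations.",
            n=10,
        )
    return tuple(result)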