def test_fbnet_v2_scriptable(self):
        e6 = {"expansion": 6}
        bn_args = {"bn_args": {"name": "bn", "momentum": 0.003}}
        arch_def = {
            "blocks": [
                # [op, c, s, n, ...]
                # stage 0
                [("conv_k3", 4, 2, 1, bn_args)],
                # stage 1
                [
                    ("ir_k3", 8, 2, 2, e6, bn_args),
                    ("ir_k5", 8, 1, 1, e6, bn_args),
                ],
            ]
        }

        model = _build_model(arch_def, dim_in=3)
        model.eval()
        model = fuse_utils.fuse_model(model, inplace=False)

        print(f"Fused model {model}")

        # Make sure the model can be scripted
        script_model = torch.jit.script(model)

        data = torch.zeros(1, 3, 32, 32)
        model_output = model(data)
        script_output = script_model(data)

        self.assertEqual(model_output.norm(), script_output.norm())

        with tempfile.TemporaryDirectory() as tmp_dir:
            fn = os.path.join(tmp_dir, "model_script.jit")
            torch.jit.save(script_model, fn)
            self.assertTrue(os.path.isfile(fn))

    def test_fbnet_v2_scriptable_empty_batch(self):
        e6 = {"expansion": 6}
        bn_args = {"bn_args": {"name": "bn", "momentum": 0.003}}
        arch_def = {
            "blocks": [
                # [op, c, s, n, ...]
                # stage 0
                [("conv_k3", 4, 2, 1, bn_args)],
                # stage 1
                [
                    ("ir_k3", 8, 2, 2, e6, bn_args),
                    ("ir_k5", 8, 1, 1, e6, bn_args),
                ],
            ]
        }

        model = _build_model(arch_def, dim_in=3)
        model.eval()
        model = fuse_utils.fuse_model(model, inplace=False)

        # Make sure the model can be scripted
        script_model = torch.jit.script(model)

        # empty batch
        data = torch.zeros(0, 3, 32, 32)
        script_output = script_model(data)

        self.assertEqual(script_output.shape, torch.Size([0, 8, 8, 8]))
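
The two tests above depend on the project's `_build_model` helper. As a point of reference, here is a minimal, self-contained sketch of the same fuse-then-script flow using only stock PyTorch APIs; the toy `ConvBNReLU` module, the shapes, and the `torch.ao.quantization.fuse_modules` call are illustrative assumptions rather than part of d2go's `fuse_utils`.

import torch


class ConvBNReLU(torch.nn.Module):
    """Toy conv/bn/relu stack standing in for the fbnet blocks above."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, stride=2, padding=1)
        self.bn = torch.nn.BatchNorm2d(8)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


toy_model = ConvBNReLU().eval()
# Eager-mode fusion of conv + bn + relu into a single module (requires eval mode).
fused = torch.ao.quantization.fuse_modules(toy_model, [["conv", "bn", "relu"]])
scripted = torch.jit.script(fused)

out = scripted(torch.zeros(1, 3, 32, 32))
assert out.shape == torch.Size([1, 8, 16, 16])
# Empty batches also work after scripting, mirroring the empty-batch test above.
assert scripted(torch.zeros(0, 3, 32, 32)).shape == torch.Size([0, 8, 16, 16])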
Example #3
    def test_fuse_model_swish(self):
        e6 = {"expansion": 6}
        dw_skip_bnrelu = {"dw_skip_bnrelu": True}
        bn_args = {"bn_args": {"name": "bn", "momentum": 0.003}}
        arch_def = {
            "blocks": [
                # [op, c, s, n, ...]
                # stage 0
                [("conv_k3", 4, 2, 1, bn_args, {
                    "relu_args": "swish"
                })],
                # stage 1
                [
                    ("ir_k3", 8, 2, 2, e6, dw_skip_bnrelu, bn_args),
                    ("ir_k5_sehsig", 8, 1, 1, e6, bn_args),
                ],
            ]
        }

        model = _build_model(arch_def, dim_in=3)
        fused_model = fuse_utils.fuse_model(model, inplace=False)
        print(model)
        print(fused_model)

        self.assertTrue(_find_modules(model, torch.nn.BatchNorm2d))
        self.assertFalse(_find_modules(fused_model, torch.nn.BatchNorm2d))
        self.assertTrue(_find_modules(fused_model, bb.Swish))

        input_size = [2, 3, 8, 8]
        run_and_compare(model, fused_model, input_size)
Example #4
def export_to_torchscript(args, task, model, inputs, output_base_dir,
                          **kwargs):
    output_dir = os.path.join(output_base_dir, "torchscript")
    with torch.no_grad():
        fused_model = fuse_utils.fuse_model(model, inplace=False)
    print("fused model {}".format(fused_model))
    torch_script_path = trace_and_save_torchscript(
        fused_model,
        inputs,
        output_dir,
        use_get_traceable=bool(args.use_get_traceable),
        trace_type=args.trace_type,
        opt_for_mobile=args.opt_for_mobile,
    )
    return torch_script_path
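
A hypothetical invocation of `export_to_torchscript` might look like the sketch below. `SimpleNamespace` stands in for the real CLI args, the flag values (`use_get_traceable`, `trace_type`, `opt_for_mobile`) are guesses at what `trace_and_save_torchscript` accepts, and the toy model and input shape are illustrative only.

from types import SimpleNamespace

import torch

# Hypothetical args object; real values come from the project's argument parser.
args = SimpleNamespace(use_get_traceable=0, trace_type="trace", opt_for_mobile=False)
toy_model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, 3), torch.nn.BatchNorm2d(4), torch.nn.ReLU()
).eval()
torch_script_path = export_to_torchscript(
    args,
    task=None,  # unused by this function
    model=toy_model,
    inputs=(torch.rand(1, 3, 224, 224),),
    output_base_dir="/tmp/export",
)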
Example #5
def default_prepare_for_quant(cfg, model):
    """
    Default implementation of preparing a model for quantization. This function will
    be called before training if QAT is enabled, or before calibration during PTQ if
    the model is not already quantized.

    NOTE:
        - This is the simplest implementation; most meta-archs need their own version.
        - For eager mode, the user should make sure the returned model has Quant/DeQuant
            stubs inserted. This can be done by wrapping the model or defining the model
            with quant stubs.
        - QAT vs. PTQ can be determined by model.training.
        - Currently the input model can be changed in place since we won't reuse the
            input model.
        - Currently this API doesn't include the final torch.ao.quantization.prepare(_qat)
            call since existing use cases don't have further steps after it.

    Args:
        cfg (CfgNode): config
        model (nn.Module): a non-quantized model.

    Returns:
        nn.Module: a model ready for QAT training or PTQ calibration
    """
    qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)

    if cfg.QUANTIZATION.EAGER_MODE:
        model = fuse_utils.fuse_model(
            model,
            is_qat=cfg.QUANTIZATION.QAT.ENABLED,
            inplace=True,
        )

        model.qconfig = qconfig
        # TODO(future diff): move the torch.ao.quantization.prepare(...) call
        # here, to be consistent with the FX branch
    else:  # FX graph mode quantization
        qconfig_dict = {"": qconfig}
        # TODO[quant-example-inputs]: needs follow up to change the api
        example_inputs = (torch.rand(1, 3, 3, 3), )
        if model.training:
            model = prepare_qat_fx(model, qconfig_dict, example_inputs)
        else:
            model = prepare_fx(model, qconfig_dict, example_inputs)

    logger.info("Setup the model with qconfig:\n{}".format(qconfig))

    return model
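
For context, a minimal end-to-end eager-mode PTQ flow that a model prepared this way would then go through (fuse, attach qconfig, prepare, calibrate, convert) is sketched below using only stock PyTorch APIs; the toy model, the `fbgemm` backend, and the input shapes are assumptions, not part of this codebase.

import torch
from torch.ao.quantization import (
    DeQuantStub,
    QuantStub,
    convert,
    fuse_modules,
    get_default_qconfig,
    prepare,
)


class TinyNet(torch.nn.Module):
    """Toy eager-mode model with the Quant/DeQuant stubs mentioned in the NOTE above."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = torch.nn.Conv2d(3, 4, 3)
        self.bn = torch.nn.BatchNorm2d(4)
        self.relu = torch.nn.ReLU()
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.bn(self.conv(self.quant(x)))))


m = TinyNet().eval()
m = fuse_modules(m, [["conv", "bn", "relu"]])  # eager-mode fusion
m.qconfig = get_default_qconfig("fbgemm")      # backend choice is an assumption
prepare(m, inplace=True)                       # insert observers
for _ in range(4):                             # calibrate with random data
    m(torch.rand(1, 3, 8, 8))
quantized = convert(m, inplace=False)          # produce the quantized model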
Example #6
def convert_torch_script(
    model, inputs, fuse_bn=True, verify_output=True, use_get_traceable=False
):
    assert isinstance(inputs, (tuple, list)), f"Invalid input types {inputs}"
    if verify_output:
        print("Run pytorch model")
        with torch.no_grad():
            output_before = model(*inputs)

    if fuse_bn:
        print("Fusing bn...")
        fused_model = fuse_utils.fuse_model(model)
        if fuse_utils.check_bn_exist(fused_model):
            print(f"WARNING: BN existed after fusing, {fused_model}")
    else:
        fused_model = copy.deepcopy(model)

    for x in fused_model.parameters():
        x.requires_grad = False

    if use_get_traceable:
        print("Get traceable model...")
        fused_model = ju.get_traceable_model(fused_model)

    print("Start tracing...")
    with torch.no_grad():
        traced_model = torch.jit.trace(fused_model, inputs, strict=False)
    # print(f"Traced model {traced_model}")
    # print(f"Traced model {traced_model.code}")

    # print("Optimizing traced model...")
    # traced_model = optimize_for_mobile(traced_model)
    print("Generating traced model lints...")
    print(generate_mobile_module_lints(traced_model))

    print("Run traced model")
    with torch.no_grad():
        outputs = traced_model(*inputs)

    if verify_output:
        paired_outputs = iu.create_pair(output_before, outputs)
        for x in iu.recursive_iterate(paired_outputs, iter_types=torch.Tensor):
            np.testing.assert_allclose(
                x.lhs.detach(), x.rhs.detach(), rtol=0, atol=1e-4
            )

    return traced_model, outputs
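
Below is a hypothetical usage of `convert_torch_script` with a simple eval-mode model; the toy module and input shape are illustrative, it assumes the surrounding imports (`fuse_utils`, `ju`, `iu`, `np`) are available, and fusion of a plain `Sequential` may be a no-op depending on which patterns `fuse_utils` recognizes.

import torch

toy_model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, 3),
    torch.nn.BatchNorm2d(4),
    torch.nn.ReLU(),
).eval()
# Fuses BN (if a matching pattern is found), traces, lints, and verifies outputs.
traced_model, outputs = convert_torch_script(
    toy_model, inputs=(torch.rand(1, 3, 16, 16),)
)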
Example #7
def d2_meta_arch_prepare_for_quant(self, cfg):
    model = self

    # Modify the model for eager mode
    if cfg.QUANTIZATION.EAGER_MODE:
        model = _apply_eager_mode_quant(cfg, model)

    model = fuse_utils.fuse_model(model, inplace=True)

    torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
    model.qconfig = (
        torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
        if model.training
        else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
    )
    logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))

    return model
Example #8
    def test_post_quant(self):
        e6 = {"expansion": 6}
        dw_skip_bnrelu = {"dw_skip_bnrelu": True}
        bn_args = {"bn_args": {"name": "bn", "momentum": 0.003}}
        arch_def = {
            "blocks": [
                # [op, c, s, n, ...]
                # stage 0
                [("conv_k3", 4, 2, 1, bn_args)],
                # stage 1
                [
                    ("ir_k3_sehsig", 8, 2, 2, e6, dw_skip_bnrelu, bn_args),
                    ("ir_k5", 8, 1, 1, e6, bn_args),
                ],
            ]
        }

        model = _build_model(arch_def, dim_in=3)
        model = torch.quantization.QuantWrapper(model)
        model = fuse_utils.fuse_model(model, inplace=False)

        print(f"Fused model {model}")

        model.qconfig = torch.quantization.default_qconfig
        print(model.qconfig)
        torch.quantization.prepare(model, inplace=True)

        # calibration
        for _ in range(5):
            data = torch.rand([2, 3, 8, 8])
            model(data)

        # Convert to quantized model
        quant_model = torch.quantization.convert(model, inplace=False)
        print(f"Quant model {quant_model}")

        # Run quantized model
        quant_output = quant_model(torch.rand([2, 3, 8, 8]))
        self.assertEqual(quant_output.shape, torch.Size([2, 8, 2, 2]))

        # Trace quantized model
        jit_model = torch.jit.trace(quant_model, data)
        jit_quant_output = jit_model(torch.rand([2, 3, 8, 8]))
        self.assertEqual(jit_quant_output.shape, torch.Size([2, 8, 2, 2]))
Example #9
    def test_fuse_model_with_reference(self):
        model = ModelWithRef()
        model.eval()
        self.assertEqual(id(model.cbr), id(model.cbr_list[0]))

        # keep the same reference after copying
        model_copy = copy.deepcopy(model)
        self.assertEqual(id(model_copy.cbr), id(model_copy.cbr_list[0]))

        fused_model = fuse_utils.fuse_model(model, inplace=False)
        self.assertEqual(id(fused_model.cbr), id(fused_model.cbr_list[0]))
        print(model)
        print(fused_model)

        self.assertTrue(_find_modules(model, torch.nn.BatchNorm2d))
        self.assertFalse(_find_modules(fused_model, torch.nn.BatchNorm2d))

        input_size = [2, 3, 8, 8]
        run_and_compare(model, fused_model, input_size)
Example #10
def default_rcnn_prepare_for_quant(self, cfg):
    model = self
    torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
    model.qconfig = (
        torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
        if model.training
        else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
    )
    if (hasattr(model, "roi_heads") and hasattr(model.roi_heads, "mask_head")
            and isinstance(model.roi_heads.mask_head, PointRendMaskHead)):
        model.roi_heads.mask_head.qconfig = None
    logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))

    # Modify the model for eager mode
    if cfg.QUANTIZATION.EAGER_MODE:
        model = _apply_eager_mode_quant(cfg, model)
        model = fuse_utils.fuse_model(model, inplace=True)
    else:
        _fx_quant_prepare(model, cfg)

    return model
Example #11
File: api.py  Project: stevenchang8/d2go
def convert_and_export_predictor(cfg, pytorch_model, predictor_type,
                                 output_dir, data_loader):
    """
    Entry point for converting and exporting a model. This involves two steps:
        - convert: converting the given `pytorch_model` to another format, currently
            mainly for quantizing the model.
        - export: exporting the converted `pytorch_model` to a predictor. This step
            should not alter the behaviour of the model.
    """
    if "int8" in predictor_type:
        if not cfg.QUANTIZATION.QAT.ENABLED:
            logger.info(
                "The model is not quantized during training, running post"
                " training quantization ...")
            pytorch_model = post_training_quantize(cfg, pytorch_model,
                                                   data_loader)
            # only check that BN exists in PTQ, as QAT still has BN inside fused ops
            assert not fuse_utils.check_bn_exist(pytorch_model)
        logger.info(
            f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
        if cfg.QUANTIZATION.EAGER_MODE:
            # TODO(future diff): move this logic to prepare_for_quant_convert
            pytorch_model = torch.quantization.convert(pytorch_model,
                                                       inplace=False)
        else:  # FX graph mode quantization
            if hasattr(pytorch_model, "prepare_for_quant_convert"):
                pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
            else:
                # TODO(future diff): move this to a default function
                pytorch_model = torch.quantization.quantize_fx.convert_fx(
                    pytorch_model)

        logger.info("Quantized Model:\n{}".format(pytorch_model))
    else:
        pytorch_model = fuse_utils.fuse_model(pytorch_model)
        logger.info("Fused Model:\n{}".format(pytorch_model))
        if fuse_utils.count_bn_exist(pytorch_model) > 0:
            logger.warning("BN existed in pytorch model after fusing.")

    return export_predictor(cfg, pytorch_model, predictor_type, output_dir,
                            data_loader)
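
The FX graph-mode branch above relies on a prior `prepare_fx` step; a minimal PTQ round trip with `prepare_fx`/`convert_fx` is sketched below for reference. The toy model, the `fbgemm` backend, and the calibration data are assumptions; the dict-style qconfig mirrors the `{"": qconfig}` form used in these examples.

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

toy_model = torch.nn.Sequential(torch.nn.Conv2d(3, 4, 3), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}  # backend is an assumption
example_inputs = (torch.rand(1, 3, 8, 8),)

prepared = prepare_fx(toy_model, qconfig_dict, example_inputs)
for _ in range(4):                                  # calibrate with random data
    prepared(torch.rand(1, 3, 8, 8))
quantized = convert_fx(prepared)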
Example #12
def default_rcnn_prepare_for_quant(self, cfg):
    model = self
    model.qconfig = set_backend_and_create_qconfig(cfg,
                                                   is_train=model.training)
    if (hasattr(model, "roi_heads") and hasattr(model.roi_heads, "mask_head")
            and isinstance(model.roi_heads.mask_head, PointRendMaskHead)):
        model.roi_heads.mask_head.qconfig = None
    logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))

    # Modify the model for eager mode
    if cfg.QUANTIZATION.EAGER_MODE:
        model = _apply_eager_mode_quant(cfg, model)
        model = fuse_utils.fuse_model(
            model,
            is_qat=cfg.QUANTIZATION.QAT.ENABLED,
            inplace=True,
        )
    else:
        _fx_quant_prepare(model, cfg)

    return model
Example #13
    def test_fuse_model_customized(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()

                self.conv = torch.nn.Conv2d(3, 4, 1)
                self.bn = torch.nn.BatchNorm2d(4, 4)
                self.conv2 = torch.nn.Conv2d(4, 2, 1)
                self.bn2 = torch.nn.BatchNorm2d(2, 2)
                self.relu2 = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.conv2(x)
                x = self.bn2(x)
                x = self.relu2(x)
                return x

        @fuse_utils.FUSE_LIST_GETTER.register(Module)
        def _get_fuser_name_cbr(
            module: torch.nn.Module,
            supported_types: typing.Dict[str, typing.List[torch.nn.Module]],
        ):
            return [["conv", "bn"], ["conv2", "bn2", "relu2"]]

        model = Module().eval()

        fused_model = fuse_utils.fuse_model(model, inplace=False)
        print(model)
        print(fused_model)

        self.assertTrue(_find_modules(model, torch.nn.BatchNorm2d))
        self.assertFalse(_find_modules(fused_model, torch.nn.BatchNorm2d))
        self.assertTrue(_find_modules(fused_model, torch.nn.ReLU))

        input_size = [1, 3, 4, 4]
        run_and_compare(model, fused_model, input_size)
Example #14
def convert_predictor(
    cfg,
    pytorch_model,
    predictor_type,
    data_loader,
):
    if "int8" in predictor_type:
        if not cfg.QUANTIZATION.QAT.ENABLED:
            logger.info(
                "The model is not quantized during training, running post"
                " training quantization ...")

            pytorch_model = post_training_quantize(cfg, pytorch_model,
                                                   data_loader)
            # only check that BN exists in PTQ, as QAT still has BN inside fused ops
            if fuse_utils.check_bn_exist(pytorch_model):
                logger.warning(
                    "Post-training quantized model has BN inside fused ops")
        logger.info(
            f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")

        if hasattr(pytorch_model, "prepare_for_quant_convert"):
            pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
        else:
            # TODO(T93870381): move this to a default function
            if cfg.QUANTIZATION.EAGER_MODE:
                pytorch_model = convert(pytorch_model, inplace=False)
            else:  # FX graph mode quantization
                pytorch_model = convert_fx(pytorch_model)

        logger.info("Quantized Model:\n{}".format(pytorch_model))
    else:
        pytorch_model = fuse_utils.fuse_model(pytorch_model)
        logger.info("Fused Model:\n{}".format(pytorch_model))
        if fuse_utils.count_bn_exist(pytorch_model) > 0:
            logger.warning("BN existed in pytorch model after fusing.")
    return pytorch_model