Example #1
import copy

import numpy as np
import torch
from torch.utils.mobile_optimizer import generate_mobile_module_lints

# fuse_utils, ju (tracing utils), and iu (iterate utils) are project-local
# helper modules assumed to be importable alongside this function.


def convert_torch_script(
    model, inputs, fuse_bn=True, verify_output=True, use_get_traceable=False
):
    """Trace `model` with torch.jit.trace, optionally fusing BatchNorm first,
    and verify that the traced outputs match the eager outputs."""
    assert isinstance(inputs, (tuple, list)), f"Invalid input types {inputs}"
    if verify_output:
        print("Run pytorch model")
        with torch.no_grad():
            output_before = model(*inputs)

    if fuse_bn:
        print("Fusing bn...")
        fused_model = fuse_utils.fuse_model(model)
        if fuse_utils.check_bn_exist(fused_model):
            print(f"WARNING: BN existed after fusing, {fused_model}")
    else:
        fused_model = copy.deepcopy(model)

    for x in fused_model.parameters():
        x.requires_grad = False

    if use_get_traceable:
        print("Get traceable model...")
        fused_model = ju.get_traceable_model(fused_model)

    print("Start tracing...")
    with torch.no_grad():
        traced_model = torch.jit.trace(fused_model, inputs, strict=False)
    # print(f"Traced model {traced_model}")
    # print(f"Traced model {traced_model.code}")

    # print("Optimizing traced model...")
    # traced_model = optimize_for_mobile(traced_model)
    print("Generating traced model lints...")
    print(generate_mobile_module_lints(traced_model))

    print("Run traced model")
    with torch.no_grad():
        outputs = traced_model(*inputs)

    if verify_output:
        paired_outputs = iu.create_pair(output_before, outputs)
        for x in iu.recursive_iterate(paired_outputs, iter_types=torch.Tensor):
            np.testing.assert_allclose(
                x.lhs.detach(), x.rhs.detach(), rtol=0, atol=1e-4
            )

    return traced_model, outputs
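
Setting the project-specific helpers (fuse_utils, ju, iu) aside, the core of this example is plain torch.jit.trace plus a numerical comparison of the eager and traced outputs. Below is a minimal, self-contained sketch of that pattern; the toy model and input shape are made up for illustration.

import numpy as np
import torch

# Toy model and dummy input, purely for illustration.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).eval()
inputs = (torch.randn(1, 3, 32, 32),)

with torch.no_grad():
    eager_out = model(*inputs)  # reference output from the eager model
    traced = torch.jit.trace(model, inputs, strict=False)
    traced_out = traced(*inputs)  # output of the traced module

# Same tolerance as the example above: equal up to 1e-4 absolute error.
np.testing.assert_allclose(
    eager_out.detach().numpy(), traced_out.detach().numpy(), rtol=0, atol=1e-4
)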
Example #2
def convert_and_export_predictor(cfg, pytorch_model, predictor_type,
                                 output_dir, data_loader):
    """
    Entry point for convert and export model. This involves two steps:
        - convert: converting the given `pytorch_model` to another format, currently
            mainly for quantizing the model.
        - export: exporting the converted `pytorch_model` to predictor. This step
            should not alter the behaviour of model.
    """
    if "int8" in predictor_type:
        if not cfg.QUANTIZATION.QAT.ENABLED:
            logger.info(
                "The model is not quantized during training, running post"
                " training quantization ...")
            pytorch_model = post_training_quantize(cfg, pytorch_model,
                                                   data_loader)
            # only check bn exists in ptq as qat still has bn inside fused ops
            assert not fuse_utils.check_bn_exist(pytorch_model)
        logger.info(
            f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
        if cfg.QUANTIZATION.EAGER_MODE:
            # TODO(future diff): move this logic to prepare_for_quant_convert
            pytorch_model = torch.quantization.convert(pytorch_model,
                                                       inplace=False)
        else:  # FX graph mode quantization
            if hasattr(pytorch_model, "prepare_for_quant_convert"):
                pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
            else:
                # TODO(future diff): move this to a default function
                pytorch_model = torch.quantization.quantize_fx.convert_fx(
                    pytorch_model)

        logger.info("Quantized Model:\n{}".format(pytorch_model))
    else:
        pytorch_model = fuse_utils.fuse_model(pytorch_model)
        logger.info("Fused Model:\n{}".format(pytorch_model))
        if fuse_utils.count_bn_exist(pytorch_model) > 0:
            logger.warning("BN existed in pytorch model after fusing.")

    return export_predictor(cfg, pytorch_model, predictor_type, output_dir,
                            data_loader)
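
post_training_quantize and export_predictor above are project functions, but the int8 branch follows PyTorch's standard eager-mode flow of prepare, calibrate, then convert. The sketch below shows that flow with a made-up toy model (TinyNet and all shapes are illustrative, not from the source).

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Toy model with the Quant/DeQuant stubs eager-mode quantization needs."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(4, 2)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

model = TinyNet().eval()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")

prepared = torch.quantization.prepare(model, inplace=False)
prepared(torch.randn(8, 4))  # calibration pass over representative data
quantized = torch.quantization.convert(prepared, inplace=False)

print(quantized(torch.randn(1, 4)))  # runs with int8 weights

Running calibration data through the prepared model lets the observers pick quantization ranges before convert swaps in the quantized kernels, which is the role data_loader plays for post_training_quantize above.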
Example #3
from torch.quantization import convert
from torch.quantization.quantize_fx import convert_fx


def convert_predictor(
    cfg,
    pytorch_model,
    predictor_type,
    data_loader,
):
    if "int8" in predictor_type:
        if not cfg.QUANTIZATION.QAT.ENABLED:
            logger.info(
                "The model is not quantized during training, running post"
                " training quantization ...")

            pytorch_model = post_training_quantize(cfg, pytorch_model,
                                                   data_loader)
            # only check bn exists in ptq as qat still has bn inside fused ops
            if fuse_utils.check_bn_exist(pytorch_model):
                logger.warning(
                    "Post training quantized model has bn inside fused ops")
        logger.info(
            f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")

        if hasattr(pytorch_model, "prepare_for_quant_convert"):
            pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
        else:
            # TODO(T93870381): move this to a default function
            if cfg.QUANTIZATION.EAGER_MODE:
                pytorch_model = convert(pytorch_model, inplace=False)
            else:  # FX graph mode quantization
                pytorch_model = convert_fx(pytorch_model)

        logger.info("Quantized Model:\n{}".format(pytorch_model))
    else:
        pytorch_model = fuse_utils.fuse_model(pytorch_model)
        logger.info("Fused Model:\n{}".format(pytorch_model))
        if fuse_utils.count_bn_exist(pytorch_model) > 0:
            logger.warning("BN existed in pytorch model after fusing.")
    return pytorch_model
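
In the non-int8 branch, fuse_utils.fuse_model and count_bn_exist are project helpers. A rough standalone equivalent can be sketched with PyTorch's torch.quantization.fuse_modules plus a scan for leftover BatchNorm modules; the toy model is for illustration only.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()

# Fold Conv + BN + ReLU; the BatchNorm module is replaced by an Identity.
fused = torch.quantization.fuse_modules(model, [["0", "1", "2"]], inplace=False)

# count_bn_exist-style check: how many BatchNorm modules remain after fusing?
remaining_bn = sum(
    isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
    for m in fused.modules()
)
if remaining_bn > 0:
    print("BN existed in pytorch model after fusing.")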