def test_fbnet_v2_scriptable(self):
    e6 = {"expansion": 6}
    bn_args = {"bn_args": {"name": "bn", "momentum": 0.003}}
    arch_def = {
        "blocks": [
            # [op, c, s, n, ...]
            # stage 0
            [("conv_k3", 4, 2, 1, bn_args)],
            # stage 1
            [
                ("ir_k3", 8, 2, 2, e6, bn_args),
                ("ir_k5", 8, 1, 1, e6, bn_args),
            ],
        ]
    }
    model = _build_model(arch_def, dim_in=3)
    model.eval()
    model = fuse_utils.fuse_model(model, inplace=False)
    print(f"Fused model {model}")

    # Make sure the model can be scripted
    script_model = torch.jit.script(model)

    data = torch.zeros(1, 3, 32, 32)
    model_output = model(data)
    script_output = script_model(data)
    self.assertEqual(model_output.norm(), script_output.norm())

    with tempfile.TemporaryDirectory() as tmp_dir:
        fn = os.path.join(tmp_dir, "model_script.jit")
        torch.jit.save(script_model, fn)
        self.assertTrue(os.path.isfile(fn))
def test_fbnet_v2_scriptable_empty_batch(self):
    e6 = {"expansion": 6}
    bn_args = {"bn_args": {"name": "bn", "momentum": 0.003}}
    arch_def = {
        "blocks": [
            # [op, c, s, n, ...]
            # stage 0
            [("conv_k3", 4, 2, 1, bn_args)],
            # stage 1
            [
                ("ir_k3", 8, 2, 2, e6, bn_args),
                ("ir_k5", 8, 1, 1, e6, bn_args),
            ],
        ]
    }
    model = _build_model(arch_def, dim_in=3)
    model.eval()
    model = fuse_utils.fuse_model(model, inplace=False)

    # Make sure the model can be scripted
    script_model = torch.jit.script(model)

    # empty batch
    data = torch.zeros(0, 3, 32, 32)
    script_output = script_model(data)
    self.assertEqual(script_output.shape, torch.Size([0, 8, 8, 8]))
def test_fuse_model_swish(self):
    e6 = {"expansion": 6}
    dw_skip_bnrelu = {"dw_skip_bnrelu": True}
    bn_args = {"bn_args": {"name": "bn", "momentum": 0.003}}
    arch_def = {
        "blocks": [
            # [op, c, s, n, ...]
            # stage 0
            [("conv_k3", 4, 2, 1, bn_args, {"relu_args": "swish"})],
            # stage 1
            [
                ("ir_k3", 8, 2, 2, e6, dw_skip_bnrelu, bn_args),
                ("ir_k5_sehsig", 8, 1, 1, e6, bn_args),
            ],
        ]
    }
    model = _build_model(arch_def, dim_in=3)
    fused_model = fuse_utils.fuse_model(model, inplace=False)
    print(model)
    print(fused_model)

    self.assertTrue(_find_modules(model, torch.nn.BatchNorm2d))
    self.assertFalse(_find_modules(fused_model, torch.nn.BatchNorm2d))
    self.assertTrue(_find_modules(fused_model, bb.Swish))

    input_size = [2, 3, 8, 8]
    run_and_compare(model, fused_model, input_size)
def export_to_torchscript(args, task, model, inputs, output_base_dir, **kwargs):
    output_dir = os.path.join(output_base_dir, "torchscript")
    with torch.no_grad():
        fused_model = fuse_utils.fuse_model(model, inplace=False)
        print("fused model {}".format(fused_model))
        torch_script_path = trace_and_save_torchscript(
            fused_model,
            inputs,
            output_dir,
            use_get_traceable=bool(args.use_get_traceable),
            trace_type=args.trace_type,
            opt_for_mobile=args.opt_for_mobile,
        )
    return torch_script_path
def default_prepare_for_quant(cfg, model):
    """
    Default implementation of preparing a model for quantization. This function will
    be called before training if QAT is enabled, or before calibration during PTQ if
    the model is not already quantized.

    NOTE:
        - This is the simplest implementation; most meta-archs need their own version.
        - For an eager-mode model, the user should make sure the returned model has
          Quant/DeQuant stubs inserted. This can be done by wrapping the model or by
          defining the model with quant stubs.
        - QAT vs. PTQ can be determined by `model.training`.
        - Currently the input model can be changed in-place since we won't re-use the
          input model.
        - Currently this API doesn't include the final `torch.ao.quantization.prepare(_qat)`
          call since existing use cases don't have further steps after it.

    Args:
        cfg (CfgNode): config
        model (nn.Module): a non-quantized model.

    Return:
        nn.Module: a model ready for QAT training or PTQ calibration.
    """
    qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)

    if cfg.QUANTIZATION.EAGER_MODE:
        model = fuse_utils.fuse_model(
            model,
            is_qat=cfg.QUANTIZATION.QAT.ENABLED,
            inplace=True,
        )
        model.qconfig = qconfig
        # TODO(future diff): move the torch.ao.quantization.prepare(...) call
        # here, to be consistent with the FX branch
    else:  # FX graph mode quantization
        qconfig_dict = {"": qconfig}
        # TODO[quant-example-inputs]: needs follow up to change the api
        example_inputs = (torch.rand(1, 3, 3, 3),)
        if model.training:
            model = prepare_qat_fx(model, qconfig_dict, example_inputs)
        else:
            model = prepare_fx(model, qconfig_dict, example_inputs)

    logger.info("Setup the model with qconfig:\n{}".format(qconfig))
    return model
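# A minimal, self-contained sketch of the eager-mode "Quant/DeQuant stubs inserted"
# requirement described in the docstring above. `_QuantReadyExample` is illustrative
# only and is not part of this codebase.
class _QuantReadyExample(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub/DeQuantStub mark where tensors enter and leave the
        # quantized region of the model
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(3, 4, 3)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))


# Wrapping an existing float module with torch.ao.quantization.QuantWrapper(module)
# inserts the same stubs without modifying the module's own forward().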
def convert_torch_script(
    model, inputs, fuse_bn=True, verify_output=True, use_get_traceable=False
):
    assert isinstance(inputs, (tuple, list)), f"Invalid input types {inputs}"

    if verify_output:
        print("Run pytorch model")
        with torch.no_grad():
            output_before = model(*inputs)

    if fuse_bn:
        print("Fusing bn...")
        fused_model = fuse_utils.fuse_model(model)
        if fuse_utils.check_bn_exist(fused_model):
            print(f"WARNING: BN existed after fusing, {fused_model}")
    else:
        fused_model = copy.deepcopy(model)

    for x in fused_model.parameters():
        x.requires_grad = False

    if use_get_traceable:
        print("Get traceable model...")
        fused_model = ju.get_traceable_model(fused_model)

    print("Start tracing...")
    with torch.no_grad():
        traced_model = torch.jit.trace(fused_model, inputs, strict=False)

    # print(f"Traced model {traced_model}")
    # print(f"Traced model {traced_model.code}")
    # print("Optimizing traced model...")
    # traced_model = optimize_for_mobile(traced_model)

    print("Generating traced model lints...")
    print(generate_mobile_module_lints(traced_model))

    print("Run traced model")
    with torch.no_grad():
        outputs = traced_model(*inputs)

    if verify_output:
        paired_outputs = iu.create_pair(output_before, outputs)
        for x in iu.recursive_iterate(paired_outputs, iter_types=torch.Tensor):
            np.testing.assert_allclose(
                x.lhs.detach(), x.rhs.detach(), rtol=0, atol=1e-4
            )

    return traced_model, outputs
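# A hedged usage sketch for convert_torch_script above. The model and input shape
# are illustrative assumptions; whether a given Conv/BN pattern actually fuses
# depends on the module types fuse_utils supports.
def _example_convert_torch_script():
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    ).eval()
    inputs = (torch.rand(1, 3, 32, 32),)
    # traces the (possibly fused) model and verifies outputs match within atol=1e-4
    traced_model, outputs = convert_torch_script(model, inputs, fuse_bn=True)
    return traced_model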
def d2_meta_arch_prepare_for_quant(self, cfg):
    model = self

    # Modify the model for eager mode
    if cfg.QUANTIZATION.EAGER_MODE:
        model = _apply_eager_mode_quant(cfg, model)
        model = fuse_utils.fuse_model(model, inplace=True)

    torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
    model.qconfig = (
        torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
        if model.training
        else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
    )

    logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))
    return model
def test_post_quant(self):
    e6 = {"expansion": 6}
    dw_skip_bnrelu = {"dw_skip_bnrelu": True}
    bn_args = {"bn_args": {"name": "bn", "momentum": 0.003}}
    arch_def = {
        "blocks": [
            # [op, c, s, n, ...]
            # stage 0
            [("conv_k3", 4, 2, 1, bn_args)],
            # stage 1
            [
                ("ir_k3_sehsig", 8, 2, 2, e6, dw_skip_bnrelu, bn_args),
                ("ir_k5", 8, 1, 1, e6, bn_args),
            ],
        ]
    }
    model = _build_model(arch_def, dim_in=3)
    model = torch.quantization.QuantWrapper(model)
    model = fuse_utils.fuse_model(model, inplace=False)
    print(f"Fused model {model}")

    model.qconfig = torch.quantization.default_qconfig
    print(model.qconfig)
    torch.quantization.prepare(model, inplace=True)

    # calibration
    for _ in range(5):
        data = torch.rand([2, 3, 8, 8])
        model(data)

    # Convert to quantized model
    quant_model = torch.quantization.convert(model, inplace=False)
    print(f"Quant model {quant_model}")

    # Run quantized model
    quant_output = quant_model(torch.rand([2, 3, 8, 8]))
    self.assertEqual(quant_output.shape, torch.Size([2, 8, 2, 2]))

    # Trace quantized model
    jit_model = torch.jit.trace(quant_model, data)
    jit_quant_output = jit_model(torch.rand([2, 3, 8, 8]))
    self.assertEqual(jit_quant_output.shape, torch.Size([2, 8, 2, 2]))
def test_fuse_model_with_reference(self):
    model = ModelWithRef()
    model.eval()
    self.assertEqual(id(model.cbr), id(model.cbr_list[0]))

    # keep the same reference after copying
    model_copy = copy.deepcopy(model)
    self.assertEqual(id(model_copy.cbr), id(model_copy.cbr_list[0]))

    fused_model = fuse_utils.fuse_model(model, inplace=False)
    self.assertEqual(id(fused_model.cbr), id(fused_model.cbr_list[0]))
    print(model)
    print(fused_model)

    self.assertTrue(_find_modules(model, torch.nn.BatchNorm2d))
    self.assertFalse(_find_modules(fused_model, torch.nn.BatchNorm2d))

    input_size = [2, 3, 8, 8]
    run_and_compare(model, fused_model, input_size)
def default_rcnn_prepare_for_quant(self, cfg):
    model = self
    torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
    model.qconfig = (
        torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
        if model.training
        else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
    )
    if (
        hasattr(model, "roi_heads")
        and hasattr(model.roi_heads, "mask_head")
        and isinstance(model.roi_heads.mask_head, PointRendMaskHead)
    ):
        model.roi_heads.mask_head.qconfig = None
    logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))

    # Modify the model for eager mode
    if cfg.QUANTIZATION.EAGER_MODE:
        model = _apply_eager_mode_quant(cfg, model)
        model = fuse_utils.fuse_model(model, inplace=True)
    else:
        _fx_quant_prepare(model, cfg)
    return model
def convert_and_export_predictor(cfg, pytorch_model, predictor_type, output_dir, data_loader):
    """
    Entry point for converting and exporting the model. This involves two steps:
        - convert: converting the given `pytorch_model` to another format, currently
            mainly for quantizing the model.
        - export: exporting the converted `pytorch_model` to a predictor. This step
            should not alter the behaviour of the model.
    """
    if "int8" in predictor_type:
        if not cfg.QUANTIZATION.QAT.ENABLED:
            logger.info(
                "The model is not quantized during training, running post"
                " training quantization ..."
            )
            pytorch_model = post_training_quantize(cfg, pytorch_model, data_loader)
            # only check bn exists in ptq as qat still has bn inside fused ops
            assert not fuse_utils.check_bn_exist(pytorch_model)
        logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
        if cfg.QUANTIZATION.EAGER_MODE:
            # TODO(future diff): move this logic to prepare_for_quant_convert
            pytorch_model = torch.quantization.convert(pytorch_model, inplace=False)
        else:  # FX graph mode quantization
            if hasattr(pytorch_model, "prepare_for_quant_convert"):
                pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
            else:
                # TODO(future diff): move this to a default function
                pytorch_model = torch.quantization.quantize_fx.convert_fx(pytorch_model)
        logger.info("Quantized Model:\n{}".format(pytorch_model))
    else:
        pytorch_model = fuse_utils.fuse_model(pytorch_model)
        logger.info("Fused Model:\n{}".format(pytorch_model))
        if fuse_utils.count_bn_exist(pytorch_model) > 0:
            logger.warning("BN existed in pytorch model after fusing.")

    return export_predictor(cfg, pytorch_model, predictor_type, output_dir, data_loader)
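# A hedged usage sketch of the entry point above. `build_model` and
# `build_data_loader` are hypothetical helpers and "torchscript_int8" is an
# illustrative predictor type; none of these names are guaranteed by this module:
#
#   pytorch_model = build_model(cfg)      # hypothetical model builder
#   data_loader = build_data_loader(cfg)  # hypothetical calibration data loader
#   predictor_path = convert_and_export_predictor(
#       cfg,
#       pytorch_model,
#       predictor_type="torchscript_int8",  # "int8" selects the quantization branch
#       output_dir="./export",
#       data_loader=data_loader,
#   )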
def default_rcnn_prepare_for_quant(self, cfg):
    model = self
    model.qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
    if (
        hasattr(model, "roi_heads")
        and hasattr(model.roi_heads, "mask_head")
        and isinstance(model.roi_heads.mask_head, PointRendMaskHead)
    ):
        model.roi_heads.mask_head.qconfig = None
    logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))

    # Modify the model for eager mode
    if cfg.QUANTIZATION.EAGER_MODE:
        model = _apply_eager_mode_quant(cfg, model)
        model = fuse_utils.fuse_model(
            model,
            is_qat=cfg.QUANTIZATION.QAT.ENABLED,
            inplace=True,
        )
    else:
        _fx_quant_prepare(model, cfg)
    return model
def test_fuse_model_customized(self):
    class Module(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 4, 1)
            self.bn = torch.nn.BatchNorm2d(4, 4)
            self.conv2 = torch.nn.Conv2d(4, 2, 1)
            self.bn2 = torch.nn.BatchNorm2d(2, 2)
            self.relu2 = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu2(x)
            return x

    @fuse_utils.FUSE_LIST_GETTER.register(Module)
    def _get_fuser_name_cbr(
        module: torch.nn.Module,
        supported_types: typing.Dict[str, typing.List[torch.nn.Module]],
    ):
        return [["conv", "bn"], ["conv2", "bn2", "relu2"]]

    model = Module().eval()
    fused_model = fuse_utils.fuse_model(model, inplace=False)
    print(model)
    print(fused_model)

    self.assertTrue(_find_modules(model, torch.nn.BatchNorm2d))
    self.assertFalse(_find_modules(fused_model, torch.nn.BatchNorm2d))
    self.assertTrue(_find_modules(fused_model, torch.nn.ReLU))

    input_size = [1, 3, 4, 4]
    run_and_compare(model, fused_model, input_size)
def convert_predictor(
    cfg,
    pytorch_model,
    predictor_type,
    data_loader,
):
    if "int8" in predictor_type:
        if not cfg.QUANTIZATION.QAT.ENABLED:
            logger.info(
                "The model is not quantized during training, running post"
                " training quantization ..."
            )
            pytorch_model = post_training_quantize(cfg, pytorch_model, data_loader)
            # only check bn exists in ptq as qat still has bn inside fused ops
            if fuse_utils.check_bn_exist(pytorch_model):
                logger.warning("Post training quantized model has bn inside fused ops")
        logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
        if hasattr(pytorch_model, "prepare_for_quant_convert"):
            pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
        else:
            # TODO(T93870381): move this to a default function
            if cfg.QUANTIZATION.EAGER_MODE:
                pytorch_model = convert(pytorch_model, inplace=False)
            else:  # FX graph mode quantization
                pytorch_model = convert_fx(pytorch_model)
        logger.info("Quantized Model:\n{}".format(pytorch_model))
    else:
        pytorch_model = fuse_utils.fuse_model(pytorch_model)
        logger.info("Fused Model:\n{}".format(pytorch_model))
        if fuse_utils.count_bn_exist(pytorch_model) > 0:
            logger.warning("BN existed in pytorch model after fusing.")
    return pytorch_model