def load_student_model(config,
                       *,
                       qat=False,
                       preload_state_dict=False,
                       preload_dir='results'):
    """Build the student model and optionally prepare it for QAT.

    Args:
        config: Passed through to the model builder and to
            ``make_student_save_name``; its structure is not visible here.
        qat: When true, the model is built with ``get_qat_model`` and then
            prepared for quantization-aware training in place.
        preload_state_dict: When true, a state dict is read from
            ``make_student_save_name(preload_dir, config) + '.pt'`` and loaded
            into the model before any QAT preparation.
        preload_dir: Directory handed to ``make_student_save_name`` when
            building the checkpoint path.

    Returns:
        A ``(student_model, student_model_name)`` tuple as produced by the
        selected builder.
    """
    # Both builders are invoked with pretrained=False.
    build = get_qat_model if qat else get_model
    student_model, student_model_name = build(config, pretrained=False)

    if preload_state_dict:
        checkpoint_path = make_student_save_name(preload_dir, config) + '.pt'
        student_model.load_state_dict(torch.load(checkpoint_path))

    if not qat:
        return student_model, student_model_name

    # Fusion is optional: only models exposing fuse_model() are fused.
    if hasattr(student_model, 'fuse_model'):
        student_model.fuse_model()
    student_model.qconfig = tq.get_default_qat_qconfig('fbgemm')
    tq.prepare_qat(student_model, inplace=True)
    return student_model, student_model_name
Example #2
0
    def __init__(
        self,
        start_step: int = 0,
        enable_observer: Tuple[int, Optional[int]] = (0, None),
        freeze_bn_step: Optional[int] = None,
        qconfig_dicts: Optional[
            Dict[str, Optional[Dict[str, Union[QConfig, QConfigDynamic]]]]
        ] = None,
        preserved_attrs: Optional[List[str]] = None,
        skip_conversion: bool = False,
    ) -> None:
        """
        Validate the QAT schedule and build the list of step-triggered transforms.

        Args:
            start_step: The training step at which QAT is enabled. The model is
                always mutated with the appropriate stubs, but they are disabled
                until the start of this training step.
                See FakeQuantizeBase.fake_quant_enabled
            enable_observer: The half-open interval [a, b) in steps during which the
                observers are enabled. See FakeQuantizeBase.observer_enabled. If
                b is None, the observer is never disabled once enabled.
            freeze_bn_step: If specified, the step at which we freeze the
                collection of batch normalization layer statistics for QAT.
            qconfig_dicts: If given, used for quantization of the model during
                training. When omitted or empty, a single default QAT qconfig
                under the "" key is used; any entry whose value is empty or None
                is likewise replaced with that default.
            preserved_attrs: If provided, a list of attributes to preserve across
                quantized modules. These are preserved only if they already exist.
            skip_conversion: Stored on the instance unchanged. Presumably tells
                the code that reads it to skip converting the prepared model to
                a quantized one -- confirm against that code.

        Raises:
            ValueError: If `start_step`, the observer start, or `freeze_bn_step`
                is negative, or if the observer interval [a, b) is empty.
        """
        if start_step < 0:
            raise ValueError(
                f"The starting step of QAT must be non-negative. Got {start_step}."
            )
        start_observer, end_observer = enable_observer
        if start_observer < 0:
            raise ValueError(
                f"The starting step for the observer must be non-negative. Got {start_observer}."
            )
        # Compare against None explicitly: an end of 0 is a real (and always
        # empty) interval bound, not "no end".
        if end_observer is not None and end_observer <= start_observer:
            raise ValueError(
                f"The observation interval must contain at least one step. Got [{start_observer}, {end_observer})."
            )
        if freeze_bn_step is not None and freeze_bn_step < 0:
            raise ValueError(
                f"The step at which batch norm layers are frozen must be non-negative. Got {freeze_bn_step}."
            )

        self.transforms: List[ModelTransform] = []
        if start_step > 0:
            self.transforms.extend(
                [
                    # Enabled by default, so the assumption for > 0 is that the
                    # user wants it disabled then enabled.
                    ModelTransform(
                        fn=torch.quantization.disable_fake_quant,
                        step=0,
                        message="Disable fake quant",
                    ),
                    ModelTransform(
                        fn=torch.quantization.enable_fake_quant,
                        step=start_step,
                        message="Enable fake quant to start QAT",
                    ),
                ]
            )
        if start_observer > 0:
            self.transforms.extend(
                # Same reasoning as for start_step: disable at step 0, then
                # enable at the requested step.
                [
                    ModelTransform(
                        fn=torch.quantization.disable_observer,
                        step=0,
                        message="Disable observer",
                    ),
                    ModelTransform(
                        fn=torch.quantization.enable_observer,
                        step=start_observer,
                        message="Start observer",
                    ),
                ]
            )
        if end_observer is not None:
            self.transforms.append(
                ModelTransform(
                    fn=torch.quantization.disable_observer,
                    step=end_observer,
                    message="End observer",
                )
            )
        if freeze_bn_step is not None:
            self.transforms.append(
                ModelTransform(
                    fn=torch.nn.intrinsic.qat.freeze_bn_stats,
                    step=freeze_bn_step,
                    message="Freeze BN",
                )
            )

        self.prepared: Optional[torch.nn.Module] = None
        self.preserved_attrs = (
            set() if preserved_attrs is None else set(preserved_attrs)
        )

        # Fall back to the default QAT qconfig for a missing/empty mapping and
        # for any individual entry that is empty or None.
        if not qconfig_dicts:
            resolved_qconfigs = {"": {"": get_default_qat_qconfig()}}
        else:
            resolved_qconfigs = {
                key: value if value else {"": get_default_qat_qconfig()}
                for key, value in qconfig_dicts.items()
            }
        self.qconfig_dicts: QConfigDicts = resolved_qconfigs
        self.quantized: Optional[torch.nn.Module] = None
        self.skip_conversion = skip_conversion
Example #3
0
def test_tuple_lowered():
    """Check that a quantized model returning two outputs lowers to a Relay tuple.

    A small QAT-prepared backbone is converted to int8, traced, and imported
    through the PyTorch frontend. The resulting main function body must be a
    2-tuple of ``qnn.dequantize`` calls with different scales.
    """
    # See the following discuss thread for details
    # https://discuss.tvm.apache.org/t/bug-frontend-pytorch-relay-ir-is-inconsistent-with-that-of-the-original-model/12010

    class ConvBnRelu(nn.Module):
        """Conv2d followed by BatchNorm2d and an in-place ReLU."""

        def __init__(self,
                     inp,
                     oup,
                     kernel_size=3,
                     stride=1,
                     padding=1,
                     bias=True,
                     groups=1):
            super().__init__()
            # With groups > 1 the conv keeps `inp` channels and `oup` is
            # ignored; otherwise it maps inp -> oup.
            out_channels = inp if groups > 1 else oup
            self.conv = nn.Conv2d(inp,
                                  out_channels,
                                  kernel_size,
                                  stride,
                                  padding,
                                  bias=bias,
                                  groups=groups)
            self.bn = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, inputs):
            return self.relu(self.bn(self.conv(inputs)))

    def conv_bn(inp, oup, stride=1, width_multiplier=1):
        # width_multiplier is accepted but not used.
        return ConvBnRelu(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)

    def conv_dw(inp, oup, stride, width_multiplier=1, padding=1):
        # Grouped 3x3 block followed by a 1x1 block; width_multiplier is
        # accepted but not used.
        block = nn.Sequential()
        block.add_module(
            "depth_wise",
            ConvBnRelu(inp,
                       oup,
                       kernel_size=3,
                       stride=stride,
                       padding=padding,
                       bias=False,
                       groups=inp),
        )
        block.add_module(
            "point_wise",
            ConvBnRelu(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        )
        return block

    class Backbone(nn.Module):
        """Two-stage backbone whose forward returns both stage outputs as a list."""

        def __init__(self, width_multiplier=1):
            super().__init__()
            self.width_multiplier = width_multiplier
            self.conv1 = conv_bn(3, 16, 2, self.width_multiplier)
            self.conv2 = conv_dw(16, 32, 1, self.width_multiplier)

        def forward(self, inputs):
            stage1 = self.conv1(inputs)
            stage2 = self.conv2(stage1)
            return [stage1, stage2]

    class QuantizableBackbone(nn.Module):
        """Backbone wrapped with a quant stub on input and dequant on each output."""

        def __init__(self, inputsize=(128, 128)):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()
            self.backbone = Backbone()

        def fuse_model(self):
            # Fuse conv/bn/relu in place, only for modules that are exactly
            # ConvBnRelu.
            for module in self.modules():
                if type(module) is not ConvBnRelu:
                    continue
                torch.quantization.fuse_modules(module, ["conv", "bn", "relu"],
                                                inplace=True)

        def forward(self, input):
            quantized = self.quant(input)
            out0, out1 = self.backbone(quantized)
            return self.dequant(out0), self.dequant(out1)

    fp32_input = torch.randn(1, 3, 128, 128)

    # QAT preparation: fuse while in train mode, attach a qconfig, then prepare.
    qat_model = QuantizableBackbone()
    qat_model.train()
    qat_model.fuse_model()
    qat_model.qconfig = get_default_qat_qconfig("qnnpack")
    prepare_qat(qat_model, inplace=True)

    # One forward pass in eval mode before conversion.
    qat_model.eval()
    qat_model(fp32_input)

    int8_model = torch.quantization.convert(qat_model, inplace=True)
    traced = torch.jit.trace(int8_model, fp32_input).eval()

    mod, _ = relay.frontend.from_pytorch(
        traced, [("input", (fp32_input.shape, "float32"))])
    body = mod["main"].body

    assert isinstance(body, relay.Tuple) and len(body) == 2
    first_dq, second_dq = body
    assert str(first_dq.op) == "qnn.dequantize" and str(second_dq.op) == "qnn.dequantize"
    first_scale = first_dq.args[1].data.numpy().item()
    second_scale = second_dq.args[1].data.numpy().item()
    assert first_scale != second_scale