def __init__(self):
    super(Net, self).__init__()
    # Swap each float layer for the corresponding Quant-series API.
    #self.conv1 = torch.nn.Conv2d(1, 32, (5, 5), padding=(2, 2), bias=True)
    self.conv1 = qnn.QuantConv2d(1, 32, (5, 5), padding=(2, 2), bias=True)
    #self.conv2 = torch.nn.Conv2d(32, 64, (5, 5), padding=(2, 2), bias=True)
    self.conv2 = qnn.QuantConv2d(32, 64, (5, 5), padding=(2, 2), bias=True)
    #self.fc1 = torch.nn.Linear(64 * 7 * 7, 1024, bias=True)
    self.fc1 = qnn.QuantLinear(64 * 7 * 7, 1024, bias=True)
    #self.fc2 = torch.nn.Linear(1024, 10, bias=True)
    self.fc2 = qnn.QuantLinear(1024, 10, bias=True)
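Because the Quant* layers are drop-in replacements for their float counterparts, the forward pass is unchanged from the original model. Below is a minimal sketch of a matching forward method; it is not part of the original listing, and it assumes 28x28 MNIST-style inputs with a 2x2 max-pool after each convolution, which is what produces the 64 * 7 * 7 flatten size expected by fc1.

import torch.nn.functional as F

def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 28x28 -> 14x14 (assumed pooling)
    x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 14x14 -> 7x7
    x = x.view(x.size(0), -1)                   # flatten to (N, 64 * 7 * 7)
    x = F.relu(self.fc1(x))
    return self.fc2(x)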
def conv3x3(in_planes: int,
            out_planes: int,
            stride: int = 1,
            groups: int = 1,
            dilation: int = 1,
            quantize: bool = False) -> nn.Conv2d:
    """3x3 convolution with padding"""
    if quantize:
        return quant_nn.QuantConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                                    padding=dilation, groups=groups, bias=False, dilation=dilation)
    else:
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=dilation, groups=groups, bias=False, dilation=dilation)
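A brief usage sketch (not from the original source): the quantize flag only selects between the fake-quantized and float implementations, so call sites are identical either way. conv1x1 below follows the same pattern.

# Both calls return a module with the same forward signature; only the flag
# decides whether the layer runs through quant_nn's fake quantization.
conv_q  = conv3x3(64, 128, stride=2, quantize=True)  # quant_nn.QuantConv2d
conv_fp = conv3x3(64, 128, stride=2)                 # plain nn.Conv2d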
def test_initialize_deactivate(self):
    no_replace_list = ["Linear"]
    custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]

    quant_modules.initialize(no_replace_list, custom_quant_modules)

    # The custom mapping overrides the no_replace_list entry, so both Linear
    # and Conv2d end up patched to their quantized versions.
    assert (type(quant_nn.QuantLinear(16, 256, 3)) == type(torch.nn.Linear(16, 256, 3)))
    assert (type(quant_nn.QuantConv2d(16, 256, 3)) == type(torch.nn.Conv2d(16, 256, 3)))

    quant_modules.deactivate()
def test_simple_default_args(self):
    replacement_helper = QuantModuleReplacementHelper()
    replacement_helper.prepare_state()
    replacement_helper.apply_quant_modules()

    # With default arguments every supported module, including Linear, is
    # replaced with its quantized version, so the types match.
    assert (type(quant_nn.QuantLinear(16, 256, 3)) == type(torch.nn.Linear(16, 256, 3)))
    assert (type(quant_nn.QuantConv2d(16, 256, 3)) == type(torch.nn.Conv2d(16, 256, 3)))

    replacement_helper.restore_float_modules()
def test_with_no_replace_list(self):
    no_replace_list = ["Linear"]
    custom_quant_modules = None

    replacement_helper = QuantModuleReplacementHelper()
    replacement_helper.prepare_state(no_replace_list, custom_quant_modules)
    replacement_helper.apply_quant_modules()

    # Linear module should not be replaced with its quantized version
    assert (type(quant_nn.QuantLinear(16, 256, 3)) != type(torch.nn.Linear(16, 256, 3)))
    assert (type(quant_nn.QuantConv2d(16, 256, 3)) == type(torch.nn.Conv2d(16, 256, 3)))

    replacement_helper.restore_float_modules()
def test_with_custom_quant_modules(self):
    no_replace_list = ["Linear"]
    custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]

    replacement_helper = QuantModuleReplacementHelper()
    replacement_helper.prepare_state(no_replace_list, custom_quant_modules)
    replacement_helper.apply_quant_modules()

    # Although no_replace_list indicates that the Linear module should not be
    # replaced with its quantized version, custom_quant_modules still contains
    # the Linear module's mapping, so it will be replaced.
    assert (type(quant_nn.QuantLinear(16, 256, 3)) == type(torch.nn.Linear(16, 256, 3)))
    assert (type(quant_nn.QuantConv2d(16, 256, 3)) == type(torch.nn.Conv2d(16, 256, 3)))

    replacement_helper.restore_float_modules()
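These tests exercise the monkey-patching mechanism: applying the quant modules swaps the entries inside torch.nn for their quant_nn counterparts, so any module constructed afterwards is the quantized variant, and restoring puts the float classes back. A minimal sketch of the typical end-to-end usage, assuming the pytorch_quantization package layout and a torchvision model:

import torchvision
from pytorch_quantization import quant_modules

quant_modules.initialize()             # patch torch.nn in place
model = torchvision.models.resnet50()  # now built from QuantConv2d / QuantLinear
quant_modules.deactivate()             # restore the original float modules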
def conv1x1(in_planes: int, out_planes: int, stride: int = 1, quantize: bool = False) -> nn.Conv2d:
    """1x1 convolution"""
    if quantize:
        return quant_nn.QuantConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    else:
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
def __init__(self,
             block: Type[Union[BasicBlock, Bottleneck]],
             layers: List[int],
             quantize: bool = False,
             num_classes: int = 1000,
             zero_init_residual: bool = False,
             groups: int = 1,
             width_per_group: int = 64,
             replace_stride_with_dilation: Optional[List[bool]] = None,
             norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
    super(ResNet, self).__init__()
    self._quantize = quantize
    if norm_layer is None:
        norm_layer = nn.BatchNorm2d
    self._norm_layer = norm_layer

    self.inplanes = 64
    self.dilation = 1
    if replace_stride_with_dilation is None:
        # each element in the tuple indicates if we should replace
        # the 2x2 stride with a dilated convolution instead
        replace_stride_with_dilation = [False, False, False]
    if len(replace_stride_with_dilation) != 3:
        raise ValueError("replace_stride_with_dilation should be None "
                         "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
    self.groups = groups
    self.base_width = width_per_group
    if quantize:
        self.conv1 = quant_nn.QuantConv2d(3, self.inplanes, kernel_size=7, stride=2,
                                          padding=3, bias=False)
    else:
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2,
                               padding=3, bias=False)
    self.bn1 = norm_layer(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.layer1 = self._make_layer(block, 64, layers[0], quantize=quantize)
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                   dilate=replace_stride_with_dilation[0], quantize=quantize)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                   dilate=replace_stride_with_dilation[1], quantize=quantize)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                   dilate=replace_stride_with_dilation[2], quantize=quantize)
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    if quantize:
        self.fc = quant_nn.QuantLinear(512 * block.expansion, num_classes)
    else:
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    # Zero-initialize the last BN in each residual branch,
    # so that the residual branch starts with zeros, and each residual block behaves like an identity.
    # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
    if zero_init_residual:
        for m in self.modules():
            if isinstance(m, Bottleneck):
                nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
            elif isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
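For completeness, a hypothetical factory in the style of torchvision's resnet helpers; resnet50_quant is an illustrative name, not part of the original code, and it assumes BasicBlock/Bottleneck follow the torchvision definitions.

def resnet50_quant(num_classes: int = 1000, **kwargs) -> ResNet:
    # Bottleneck with [3, 4, 6, 3] blocks per stage is the standard ResNet-50
    # layout; quantize=True routes every conv/linear through quant_nn.
    return ResNet(Bottleneck, [3, 4, 6, 3], quantize=True,
                  num_classes=num_classes, **kwargs)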