def from_float(cls, mod): if hasattr(mod, "weight_fake_quant"): # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \ # ".from_float only works for " + cls.__QAT_MODULE.__name__ if type(mod) == cls._NNIQAT_CONV_BN_MODULE: mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias) assert hasattr(mod, "activation_post_process"), \ "Input QAT module must have observer attached" weight_post_process = mod.weight_fake_quant activation_post_process = mod.activation_post_process else: assert type(mod) == cls._FLOAT_MODULE, \ " nnq." + cls.__name__ + ".from_float only works for " + \ cls._FLOAT_MODULE.__name__ + " but got:" + str(type(mod)) assert hasattr(mod, "qconfig"), \ "Input float module must have qconfig defined." activation_post_process = None if not hasattr( mod, "activation_post_process") else mod.activation_post_process if type(mod) == cls._NNI_CONV_RELU_MODULE: mod = mod[0] weight_post_process = mod.qconfig.weight() return cls.get_qconv(mod, activation_post_process, weight_post_process)
def to_float(self):
    cls = type(self)
    conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
        self.in_channels,
        self.out_channels,
        self.kernel_size,
        self.stride,
        self.padding,
        self.dilation,
        self.groups,
        self.bias is not None,
        self.padding_mode)
    conv.weight = torch.nn.Parameter(self.weight.detach())
    if self.bias is not None:
        conv.bias = torch.nn.Parameter(self.bias.detach())
    if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
        # fuse bn into conv
        conv.weight, conv.bias = fuse_conv_bn_weights(
            conv.weight, conv.bias, self.bn.running_mean, self.bn.running_var,
            self.bn.eps, self.bn.weight, self.bn.bias)
    if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
        modules = []
        modules.append(conv)
        relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
        modules.append(relu)
        conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
        conv_relu.train(self.training)
        return conv_relu
    else:
        conv.train(self.training)
        return conv
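# Hedged usage sketch (not from the original source): fusing a standalone
# Conv2d/BatchNorm2d pair with fuse_conv_bn_weights, the same helper the
# methods above rely on. The import path torch.nn.utils.fusion is an
# assumption; both modules are put in eval mode so the BN running statistics
# are the ones that get folded.
import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8).eval()
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)   # give BN non-trivial statistics
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1.0, 1.0)

fused = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
fused.weight, fused.bias = fuse_conv_bn_weights(
    conv.weight, conv.bias, bn.running_mean, bn.running_var,
    bn.eps, bn.weight, bn.bias)

x = torch.randn(1, 3, 16, 16)
# The fused conv should reproduce conv -> bn up to floating-point error
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)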
def from_float(cls, mod): r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ if hasattr(mod, 'weight_fake_quant'): # assert type(mod) == cls.__QAT_MODULE, ' nnq.' + cls.__name__ + \ # '.from_float only works for ' + cls.__QAT_MODULE.__name__ if type(mod) == nniqat.ConvBn2d: mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias) assert hasattr(mod, 'activation_post_process'), \ 'Input QAT module must have observer attached' weight_post_process = mod.weight_fake_quant activation_post_process = mod.activation_post_process else: assert type(mod) == cls._FLOAT_MODULE, \ ' nnq.' + cls.__name__ + '.from_float only works for ' + \ cls._FLOAT_MODULE.__name__ assert hasattr(mod, 'qconfig'), \ 'Input float module must have qconfig defined.' # workaround for sequential, ConvReLU2d should probably # inherit from Conv2d instead if type(mod) == nni.ConvReLU2d: activation_post_process = mod[1].activation_post_process mod = mod[0] else: activation_post_process = mod.activation_post_process weight_post_process = mod.qconfig.weight() return cls.get_qconv(mod, activation_post_process, weight_post_process)
def from_float(cls, mod): r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ if hasattr(mod, 'weight_fake_quant'): # assert type(mod) == cls.__QAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \ # cls.__QAT_MODULE.__name__ if type(mod) == nniqat.ConvBn2d: mod.weight, mod.bias = \ fuse_conv_bn_weights(mod.weight, mod.bias, mod.running_mean, mod.running_var, mod.eps, mod.gamma, mod.beta) assert hasattr( mod, 'observer'), 'Input QAT module must have observer attached' weight_observer = mod.weight_fake_quant activation_observer = mod.observer else: assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \ cls._FLOAT_MODULE.__name__ assert hasattr( mod, 'qconfig'), 'Input float module must have qconfig defined' # workaround for sequential, ConvReLU2d should probably # inherit from Conv2d instead if type(mod) == nni.ConvReLU2d: activation_observer = mod[1].observer mod = mod[0] else: activation_observer = mod.observer weight_observer = mod.qconfig.weight() weight_observer(mod.weight) act_scale, act_zp = activation_observer.calculate_qparams() assert weight_observer.dtype == torch.qint8, 'Weight observer must have a dtype of qint8' wt_scale, wt_zp = weight_observer.calculate_qparams() # Scale bias to activation_scale/2^16, this quantizes bias # to about 24 bits of precision bias_scale = float(act_scale / (2**16)) qweight = torch.quantize_linear(mod.weight.float(), float(wt_scale), int(wt_zp), torch.qint8) qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, mod.bias is not None, mod.padding_mode) qconv.set_weight(qweight) if mod.bias is not None: qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32) else: qbias = None qconv.bias = qbias qconv.scale = float(act_scale) qconv.zero_point = int(act_zp) return qconv
def from_float(cls, mod):
    if type(mod) == torch.nn.intrinsic.qat.ConvBnReLU3d:
        mod.weight, mod.bias = fuse_conv_bn_weights(
            mod.weight,
            mod.bias,
            mod.bn.running_mean,
            mod.bn.running_var,
            mod.bn.eps,
            mod.bn.weight,
            mod.bn.bias,
        )
    return super(ConvReLU3d, cls).from_float(mod)
def from_float(cls, mod): r"""Creates a quantized module from a float module or qparams_dict. Args: mod (Module): a float module, either produced by torch.quantization utilities or provided by the user """ if hasattr(mod, 'weight_fake_quant'): # assert type(mod) == cls.__QAT_MODULE, ' nnq.' + cls.__name__ + \ # '.from_float only works for ' + cls.__QAT_MODULE.__name__ if type(mod) == nniqat.ConvBn2d: mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, mod.bn.eps, mod.bn.weight, mod.bn.bias) assert hasattr(mod, 'activation_post_process'), \ 'Input QAT module must have observer attached' weight_post_process = mod.weight_fake_quant activation_post_process = mod.activation_post_process else: assert type(mod) == cls._FLOAT_MODULE, \ ' nnq.' + cls.__name__ + '.from_float only works for ' + \ cls._FLOAT_MODULE.__name__ assert hasattr(mod, 'qconfig'), \ 'Input float module must have qconfig defined.' # workaround for sequential, ConvReLU2d should probably # inherit from Conv2d instead if type(mod) == nni.ConvReLU2d: activation_post_process = mod[1].activation_post_process mod = mod[0] else: activation_post_process = mod.activation_post_process weight_post_process = mod.qconfig.weight() weight_post_process(mod.weight) act_scale, act_zp = activation_post_process.calculate_qparams() assert weight_post_process.dtype == torch.qint8, \ 'Weight observer must have a dtype of qint8' qweight = _quantize_weight(mod.weight.float(), weight_post_process) qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, mod.bias is not None, mod.padding_mode) qconv.set_weight_bias(qweight, mod.bias) qconv.scale = float(act_scale) qconv.zero_point = int(act_zp) return qconv
def quantize(self, quantizer, node, load_arg):
    mod = self.conv
    weight, bias = mod.weight, mod.bias

    if self.bn_node is not None:
        weight, bias = fuse_conv_bn_weights(
            weight, bias, self.bn.running_mean, self.bn.running_var,
            self.bn.eps, self.bn.weight, self.bn.bias)

    min_val, max_val = float(weight.min()), float(weight.max())

    act_scale, act_zp = self.scale_zeropoint()

    weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val)
    qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zp, torch.qint8)

    ctor = torch.nn.intrinsic.quantized.ConvReLU2d if self.relu_node is not None else torch.nn.quantized.Conv2d

    qconv = ctor(mod.in_channels, mod.out_channels, mod.kernel_size,
                 mod.stride, mod.padding, mod.dilation, mod.groups,
                 mod.bias is not None, mod.padding_mode)
    qconv.set_weight_bias(qweight, bias)
    qconv.scale = float(act_scale)
    qconv.zero_point = int(act_zp)
    parent_name, name = _parent_name(self.conv_node.target)
    setattr(quantizer.modules[parent_name], name, qconv)
    if self.bn_node is not None:
        parent_bn, bn_name = _parent_name(self.bn_node.target)
        # we can't just delete this because submodules' forwards (which are no
        # longer used) still try to call it, so replace it with something that
        # does nothing.
        setattr(quantizer.modules[parent_name], bn_name, IdentityModule())

    return quantizer.quantized_graph.create_node(
        'call_module',
        self.conv_node.target,
        (load_arg(self.conv_node.args[0]),), {})
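# Hedged sketch (not from the original source) of the parent-name/setattr swap
# used above: a dotted FX node target such as "features.3.conv" is split into
# the parent module path and the attribute name, and the replacement module is
# assigned at that attribute. _parent_name here is an illustrative
# re-implementation, not necessarily the real helper.
def _parent_name(target):
    """Split 'features.3.conv' into ('features.3', 'conv')."""
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

# Typical use (hypothetical names):
#   modules = dict(model.named_modules())   # modules[''] is the root model
#   parent_name, name = _parent_name(node.target)
#   setattr(modules[parent_name], name, replacement_module)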
def from_float(cls, mod):
    if type(mod) == torch.nn._intrinsic.qat.ConvBnReLU2d:
        mod.weight, mod.bias = \
            fuse_conv_bn_weights(mod.weight, mod.bias, mod.running_mean,
                                 mod.running_var, mod.eps, mod.gamma, mod.beta)
    return super(ConvReLU2d, cls).from_float(mod)
def traceHook(self, module, input: Union[T, Tuple[T, T]], output: T):
    """
    PyTorch NN module forward hook. Used to intercept NN layer execution order and dependency.
    This function will be called by PyTorch during inference.
    TODO: Add support for transposed convolution
    TODO: Add support for MaxPool2d properly
    :param module: The PyTorch NN module to be hooked.
    :param input: Input tensor. Can be a tuple if the module has multiple inputs.
    :param output: Output tensor.
    :return: The modified output.
    """
    # Modify the output
    # output[0]: layer id of the layer
    # output[1]: INT8 activation precision
    print("Tracing module {}. Type: {}. {}".format(self.layerID, type(module), module))

    # Extract the information: input precision(s), input id
    inputIDs = []
    inputPrecisions = []
    inputChannels = []
    inputGroupsSeenbySource = []
    inputHeights = []
    inputWidths = []
    input0FracBits = None
    outputPrecisionScale = None
    isAfterInput = False
    if isinstance(module, QuantStub) is False:
        if isinstance(input, T):
            nElements = input.numel()
            idx = int(input.view(nElements)[self.ID_IDX].item())
            # print('Input idx: {}'.format(idx))
            inputIDs.append(idx)
            # inputPrecision = input.view(nElements)[self.PRECISION_IDX].item()
            inputPrecision = torch.pow(
                torch.tensor(0.5, dtype=torch.float),
                torch.tensor(self.layerList[idx].outputFracBits, dtype=torch.float))
            # print('Input precision: {}'.format(inputPrecision))
            inputPrecisions.append(inputPrecision)
            inputChannels.append(self.layerList[idx].outputChannels)
            inputGroupsSeenbySource.append(self.layerList[idx].outputCurrentNumGroups)
            inputHeights.append(self.layerList[idx].outputHeight)
            inputWidths.append(self.layerList[idx].outputWidth)
            if (self.layerList[idx].operationType == 'quantstub'):
                isAfterInput = True
        elif isinstance(input, tuple):
            for tensor in input:
                nElements = tensor.numel()
                # print('Input: {}'.format(tensor.view(nElements)[0:2]))
                idx = int(tensor.view(nElements)[self.ID_IDX].item())
                # print('Input idx: {}'.format(idx))
                inputIDs.append(idx)
                # inputPrecision = tensor.view(nElements)[self.PRECISION_IDX].item()
                inputPrecision = torch.pow(
                    torch.tensor(0.5, dtype=torch.float),
                    torch.tensor(self.layerList[idx].outputFracBits, dtype=torch.float))
                # print('Input precision: {}'.format(inputPrecision))
                inputPrecisions.append(inputPrecision)
                inputChannels.append(self.layerList[idx].outputChannels)
                inputGroupsSeenbySource.append(self.layerList[idx].outputCurrentNumGroups)
                inputHeights.append(self.layerList[idx].outputHeight)
                inputWidths.append(self.layerList[idx].outputWidth)
                if (self.layerList[idx].operationType == 'quantstub'):
                    isAfterInput = True
        else:
            raise TypeError('The input argument is neither a tensor nor a tuple of tensors')

        input0FracBits = torch.round(torch.log2(1.0 / inputPrecisions[0])).view(1)[0].item()
        outputPrecisionScale = inputPrecisions[0]

    # Determine output precision scales
    if hasattr(module, 'activation_post_process'):
        outputPrecisionScale = module.activation_post_process.scale.view(1)
    elif isinstance(module, cm.EltwiseAdd):
        outputPrecisionScale = module.quant.activation_post_process.scale.view(1)
    else:
        if isinstance(module, (cm.MaxPool2dRelu, cm.AvgPool2dRelu)):
            # Number of fraction bits = log2(1/scale)
            # Number of integer bits = 8 - number of fraction bits
            # Number of fraction bits new = log2(1/scale_new) = number of fraction bits - 1 = log2(1/scale) - 1
            # log2(scale) + 1 = log2(scale_new)
            # scale_new = 2 * scale
            # iprecision0 = inputPrecisions[0].view(1)[0].item()
            # iprecision1 = inputPrecisions[1].view(1)[0].item()
            # import math
            # if math.isclose(iprecision0, iprecision1):
            #     outputPrecisionScale *= 2.0
            # else:
            #     outputPrecisionScale = torch.tensor(iprecision1) if iprecision1 > iprecision0 else torch.tensor(iprecision0)
            outputPrecisionScale = module.quant.activation_post_process.scale.view(1)
        else:
            pass
    outputFracBits = int(torch.round(torch.log2(1.0 / outputPrecisionScale)).view(1)[0].item())

    # Instantiate and insert a layer, register input/output adjacencies
    # If this is a convolution-based layer. Even the fused layer types are Conv2d thanks to inheritance
    newLayer = None
    outputChannels = output.size()[1]
    outputRelu = False
    # For the list of convolution-like layer types after qat is applied, see
    # https://github.com/pytorch/pytorch/blob/20ac7362009dd8e0aca6e72fc9357773136a83b8/torch/quantization/quantization_mappings.py#L54
    if isinstance(module,
                  (nnqat.Linear, nnqat.Conv2d, nniqat.ConvBn2d,
                   nniqat.ConvBnReLU2d, nniqat.ConvReLU2d, nniqat.LinearReLU,
                   cqat_modules.Linear, cqat_modules.Conv2d, cqat_modules.ConvBn2d,
                   cqat_modules.ConvBnReLU2d, cqat_modules.ConvReLU2d, cqat_modules.LinearReLU)):
        if isinstance(module,
                      (nniqat.ConvReLU2d, nniqat.LinearReLU, nniqat.ConvBnReLU2d,
                       cqat_modules.ConvBnReLU2d, cqat_modules.ConvReLU2d, cqat_modules.LinearReLU)):
            outputRelu = True

        # Determine padding and kernel size
        padding = 0
        kernelSize = inputWidths[0]
        kernelStride = kernelSize
        groups = 1
        if isinstance(module,
                      (nnqat.Conv2d, nniqat.ConvBn2d, nniqat.ConvBnReLU2d, nniqat.ConvReLU2d,
                       cqat_modules.Conv2d, cqat_modules.ConvBn2d,
                       cqat_modules.ConvBnReLU2d, cqat_modules.ConvReLU2d)):
            # Extract the padding for convolution.
            # Assume that the horizontal and vertical paddings are the same
            padding = module.padding[0]
            kernelSize = module.kernel_size[0]
            kernelStride = module.stride[0]
            groups = module.groups

        weight = module.weight
        bias = None
        if module.bias is not None:
            bias = module.bias

        # Perform batchnorm folding and update the quantization parameters if necessary
        if self.foldBN and isinstance(module,
                                      (nniqat.ConvBn2d, nniqat.ConvBnReLU2d,
                                       cqat_modules.ConvBn2d, cqat_modules.ConvBnReLU2d)):
            """
            Assumptions:
            - This is a fused model
            - When the fusion occurred, the model was not in eval mode, so its bn parameters are not folded at this moment
            - Warning: If BN folding has already occurred, the following folding will mess up the weights and biases
            - Reason: PyTorch's BN folding function only changes conv weights and biases, but does not affect BN parameters
            """
            # PyTorch v1.5.0
            # weight, bias = fuse_conv_bn_weights(
            #     module.weight, module.bias, module.running_mean, module.running_var,
            #     module.eps, module.gamma, module.beta)
            # PyTorch v1.6.0
            weight, bias = fuse_conv_bn_weights(module.weight, module.bias,
                                                module.bn.running_mean,
                                                module.bn.running_var,
                                                module.bn.eps,
                                                module.bn.weight,
                                                module.bn.bias)

        # Determine weight frac bits
        # The quantization observer for the weight of fused conv_bn already monitors the folded weights
        # see: https://github.com/pytorch/pytorch/blob/7f73f1d591afba823daa4a99a939217fb54d7688/torch/nn/intrinsic/qat/modules/conv_fused.py#L113
        weight_quantizer = module.weight_fake_quant
        # weight_post_process.enable_observer()
        # weight_post_process(weight)
        # weight_post_process.disable_observer()
        # weight_quantizer = custom_quant.RoundedMinMaxObserver()
        # weight_quantizer.forward(weight)
        weightPrecisionScale, _ = weight_quantizer.calculate_qparams()
        weightPrecisionScale = weightPrecisionScale.view(1)
        weightFracBits = int(torch.ceil(torch.log2(1.0 / weightPrecisionScale)).view(1)[0].item())

        # Just use the quantizer to quantize the weights before exporting
        weight = weight_quantizer.forward(weight)

        hasBias = False if bias is None else True
        # Quantize bias
        # if hasBias:
        #     bias = cqat_modules.quantize_bias(module, bias)

        # Determine the SpW parameters
        flagFoundSpWInfo = False
        pruneRangeInCluster = None
        pruneCluster = None
        sparsity = None
        for _, hook in module._forward_pre_hooks.items():
            if isinstance(hook, cm_prune.balancedPruningMethod):
                flagFoundSpWInfo = True
                pruneRangeInCluster = hook.pruneRangeInCluster
                sparsity = hook.sparsity
                pruneCluster = hook.clusterSize
        if not flagFoundSpWInfo:
            print("Using default sparsity pruning parameters for module {}".format(module))
            pruneRangeInCluster = self.defaultPruneRangeInCluster
            sparsity = 0.0
            pruneCluster = self.defaultPruneCluster

        needToPermuteWeight = True
        if isinstance(module, (nnqat.Linear, nniqat.LinearReLU, cqat_modules.Linear)):
            needToPermuteWeight = False

        newLayer = ConvInfo(
            outputFracBits=int(outputFracBits),
            outputChannels=outputChannels,
            outputRelu=outputRelu,
            inputFracBits=int(input0FracBits),
            inputHeight=inputHeights[0],
            inputWidth=inputWidths[0],
            inputBorderPadding=padding,
            inputTransConvPadding=0,
            inputChannels=inputChannels[0],
            inputGroupsSeenBySource=inputGroupsSeenbySource[0],
            weightFracBits=int(weightFracBits),
            kernelSize=kernelSize,
            kernelStride=kernelStride,
            hasBias=hasBias,
            pruneRangeInCluster=pruneRangeInCluster,
            pruneClusterSize=pruneCluster,
            sparsity=sparsity,
            channelGroups=groups,
            layerID=self.layerID,
            isAfterInput=isAfterInput,
            needToPermuteWeight=needToPermuteWeight)

        # Extract parameters
        # Extract weights
        newLayer.weightParameterFileStartPosition = self.parameterCount
        self.parameterCount += weight.numel()
        self.parameters.append(weight)
        self.parameterKeys.append(str(self.layerID) + '_weight')

        newLayer.biasParameterFileStartPosition = self.parameterCount
        self.parameterCount += outputChannels
        self.parameterKeys.append(str(self.layerID) + '_bias')
        if hasBias is False:
            bias = torch.zeros([outputChannels])
        self.parameters.append(bias)

    # if this is an average pooling layer
    elif isinstance(module, cm.AvgPool2dRelu):
        # An average pooling layer should be seen as a special case of a depth-wise convolution layer
        padding: int = 0
        if hasattr(module.padding, '__getitem__'):
            padding = module.padding[0]
        else:
            padding = module.padding

        newLayer = AvgPoolInfo(
            outputFracBits=outputFracBits,
            outputRelu=False,
            inputFracBits=int(input0FracBits),
            inputHeight=inputHeights[0],
            inputWidth=inputWidths[0],
            inputBorderPadding=padding,
            inputChannels=inputChannels[0],
            inputGroupsSeenBySource=inputGroupsSeenbySource[0],
            kernelSize=module.kernel_size,
            kernelStride=module.stride,
            divisor=module.divisor_override,
            layerID=self.layerID)

    elif isinstance(module, cm.MaxPool2dRelu):
        padding: int = 0
        if hasattr(module.padding, '__getitem__'):
            padding = module.padding[0]
        else:
            padding = module.padding

        newLayer = MaxPoolInfo(
            outputFracBits=outputFracBits,
            outputRelu=module.relu,
            inputFracBits=int(input0FracBits),
            inputHeight=inputHeights[0],
            inputWidth=inputWidths[0],
            inputBorderPadding=padding,
            inputChannels=inputChannels[0],
            inputGroupsSeenBySource=inputGroupsSeenbySource[0],
            kernelSize=module.kernel_size,
            kernelStride=module.stride,
            layerID=self.layerID)

    elif isinstance(module, cm.EltwiseAdd):
        assert inputHeights[0] == inputHeights[1], "Input heights do not match for eltwise-add operation"
        assert inputWidths[0] == inputWidths[1], "Input widths do not match for eltwise-add operation"
        assert inputChannels[0] == inputChannels[1], "Input channels do not match for eltwise-add operation"
        input1FracBits = int(torch.round(torch.log2(1.0 / inputPrecisions[1])).view(1)[0].item())

        newLayer = EltAddInfo(
            outputFracBits=outputFracBits,
            outputRelu=module.relu,
            inputHeight=inputHeights[0],
            inputWidth=inputWidths[0],
            inputChannels=inputChannels[0],
            inputLeftFracBits=int(input0FracBits),
            inputRightFracBits=int(input1FracBits),
            inputLeftGroupsSeenBySource=int(inputGroupsSeenbySource[0]),
            inputRightGroupsSeenBySource=int(inputGroupsSeenbySource[1]),
            layerID=self.layerID)

    elif isinstance(module, QuantStub):
        newLayer = QuantStubInfo(
            outputFracBits=outputFracBits,
            outputChannels=input[0].size()[1],
            outputHeight=input[0].size()[2],
            outputWidth=input[0].size()[3],
            layerID=self.layerID)

    elif isinstance(module, DeQuantStub):
        newLayer = DeQuantStubInfo(
            inputFracBits=int(input0FracBits),
            inputChannels=input[0].size()[1],
            inputGroupsSeenBySource=inputGroupsSeenbySource[0],
            inputHeight=input[0].size()[2] if len(input[0].size()) > 2 else 1,
            inputWidth=input[0].size()[3] if len(input[0].size()) > 2 else 1,
            layerID=self.layerID)

    else:
        raise TypeError(
            'The tracer hook can only be applied to QuantStub, DeQuantStub, Conv2d types, Linear types, '
            'AvgPool2d, MaxPool2d, and EltwiseAdd types. Input module type is {}'.format(type(module)))

    # Insert the layer and save the connections
    self.insertLayer(newLayer)
    if isinstance(module, QuantStub) is False:
        for idx in inputIDs:
            self.addForwardEdge(sourceLayerId=idx, sinkLayerId=self.layerID)
            self.addBackward(sourceLayerId=idx, sinkLayerId=self.layerID)

    # Propagate the layer id, and output precision to the next layer
    outputNumel = output.numel()
    output.view(outputNumel)[self.ID_IDX] = self.layerID

    self.layerID += 1
    self.layerCount += 1
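# Hedged sketch (not from the original source) showing how a hook like traceHook
# would typically be attached. 'tracer', 'model', 'dummy_input', and
# TRACEABLE_TYPES are assumed names: TRACEABLE_TYPES stands for the tuple of
# module classes the hook handles (QuantStub, DeQuantStub, the nnqat/nniqat
# conv and linear variants, and the custom cm.* ops).
TRACEABLE_TYPES = (QuantStub, DeQuantStub, nnqat.Conv2d, nnqat.Linear)  # illustrative subset

hooks = [m.register_forward_hook(tracer.traceHook)
         for m in model.modules() if isinstance(m, TRACEABLE_TYPES)]

with torch.no_grad():
    model(dummy_input)    # one inference pass drives the trace in execution order

for h in hooks:
    h.remove()            # detach the hooks once the layer graph has been built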